In [1]:

# # Task 4: Therapeutic Target Prioritization
# 
# **Objective**: Identify and rank genes with the highest therapeutic potential for ALS
# 
# ## Biological Rationale
# 
# This task integrates multiple metrics to identify genes that, when perturbed, can effectively 
# shift ALS motor neurons toward a healthier (PN-like) state. We optimize perturbation strength 
# for each gene to find the "therapeutic window" - strong enough to rescue disease phenotype 
# but not so strong as to disrupt cellular identity.
# 
# ## Prioritization Criteria
# 
# 1. **Disease Rescue (40% weight)**: Movement toward healthy PN state
#    - *Why primary?* Core therapeutic objective is to reverse disease phenotype
# 
# 2. **Coverage (30% weight)**: Percentage of cells that improve
#    - *Why important?* Treatment must benefit majority of affected cells, not just outliers
# 
# 3. **Effect Size (15% weight)**: Magnitude of perturbation effect
#    - *Why moderate?* Need sufficient strength (>3) but avoid excessive disruption (>8)
# 
# 4. **Structure Preservation (15% weight)**: Retention of cell identity
#    - *Why included?* Perturbations must maintain motor neuron characteristics
# 
# ## Output
# 
# Ranked list of therapeutic candidates with optimal perturbation strength (factor) for each gene

## 1. Setup and Configuration

In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import yaml
from scipy.spatial.distance import euclidean

# Path handling
NOTEBOOK_DIR = Path.cwd()
PROJECT_ROOT = NOTEBOOK_DIR.parent if NOTEBOOK_DIR.name == 'notebooks' else NOTEBOOK_DIR

import sys
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from utils.data_io import DataIOManager

print(f"Project root: {PROJECT_ROOT}")

Project root: /Users/lubainakothari/Desktop/perturbation_newstructure


In [3]:
# Load configuration
config_path = PROJECT_ROOT / "config" / "config.yaml"
with open(config_path, 'r') as f:
    config = yaml.safe_load(f)

RANDOM_SEED = config['random_seed']
np.random.seed(RANDOM_SEED)

# Extract perturbation factors tested
KD_FACTORS = config['perturbation']['knock_down']['factors']
KU_FACTORS = config['perturbation']['knock_up']['factors']

print(f"Perturbation factors tested:")
print(f"  Knockdown: {KD_FACTORS}")
print(f"  Knockup: {KU_FACTORS}")

# Define paths
CACHE_DIR = PROJECT_ROOT / config['data']['cache_dir']
RESULTS_DIR = PROJECT_ROOT / config['data']['results_dir']
TASK3_DIR = RESULTS_DIR / "task3"
TASK4_DIR = RESULTS_DIR / "task4"
TABLES_DIR = TASK4_DIR / "tables"
FIGURES_DIR = TASK4_DIR / "figures"

# Create output directories
for d in [TASK4_DIR, TABLES_DIR, FIGURES_DIR]:
    d.mkdir(parents=True, exist_ok=True)

# Visualization settings
plt.rcParams['figure.dpi'] = config['visualization']['figure_dpi']
sns.set_style("whitegrid")

print(f"Output directory: {TASK4_DIR}")

Perturbation factors tested:
  Knockdown: [0.2, 0.5]
  Knockup: [2.0, 3.0]
Output directory: /Users/lubainakothari/Desktop/perturbation_newstructure/results/task4


## 2. Load Task 3 Results

We build on the embedding analysis from Task 3, which computed:
- Centroid shifts (population-level movement)
- Cosine shifts (directional changes)
- Neighborhood preservation (structural integrity)
- Disease rescue scores (ALS → PN distance)

In [4]:
print("\nLoading Task 3 analyses...")

# Load comprehensive results
full_analysis = pd.read_csv(TASK3_DIR / "tables" / "full_analysis.csv")
rescue_analysis = pd.read_csv(TASK3_DIR / "tables" / "disease_rescue_analysis.csv")

# Load embeddings for coverage analysis
io_manager = DataIOManager(base_dir=str(PROJECT_ROOT / "data"), cache_dir=str(CACHE_DIR))
embeddings_dict = io_manager.load_embeddings_hdf5("task2_all_embeddings.h5")
baseline_embeddings = embeddings_dict.pop('baseline')
perturbation_embeddings = embeddings_dict

# Load metadata
metadata = pd.read_csv(RESULTS_DIR / "task2" / "tables" / "perturbation_metadata.csv")

print(f"Loaded:")
print(f"  Full analysis: {len(full_analysis)} perturbations")
print(f"  Rescue analysis: {len(rescue_analysis)} ALS perturbations")
print(f"  Embeddings: {len(perturbation_embeddings)}")
print(f"  Baseline cells: {baseline_embeddings.shape[0]:,}")


Loading Task 3 analyses...
DataIOManager initialized
  Data directory: /Users/lubainakothari/Desktop/perturbation_newstructure/data
  Cache directory: /Users/lubainakothari/Desktop/perturbation_newstructure/cache
Loading embeddings from: /Users/lubainakothari/Desktop/perturbation_newstructure/cache/task2_all_embeddings.h5


Loading: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 27/27 [00:00<00:00, 413.98it/s]

✓ Loaded 27 embedding sets
Loaded:
  Full analysis: 24 perturbations
  Rescue analysis: 12 ALS perturbations
  Embeddings: 26
  Baseline cells: 50





## 3. Compute Coverage Metric

**Coverage** = Percentage of individual cells that move closer to healthy (PN) reference

**Why this matters**: A perturbation with high rescue score (centroid movement) but low 
coverage might only benefit a small subset of cells. We want treatments that help the 
majority of affected motor neurons, not just a few outliers.

**Implementation**: For each cell, compare its distance to PN reference before vs. after 
perturbation. Count how many cells get closer.

In [5]:
print("\nComputing cell-level coverage...")

# Compute PN reference from healthy state perturbations
pn_perts = metadata[metadata['condition'] == 'PN']
pn_embeddings = [perturbation_embeddings[pid] for pid in pn_perts['perturbation_id'].values]
pn_reference = np.mean(np.vstack(pn_embeddings), axis=0)

print(f"PN reference: averaged {len(pn_embeddings)} PN perturbation centroids")

def compute_coverage(baseline_emb, perturbed_emb, reference):
    """
    Compute percentage of cells that moved closer to reference (healthy) state.
    
    Parameters:
        baseline_emb: Baseline cell embeddings
        perturbed_emb: Perturbed cell embeddings
        reference: Healthy state reference centroid
    
    Returns:
        pct_improved: Percentage of cells closer to reference after perturbation
        improved_cells: Boolean array of which cells improved
    """
    n_cells = len(baseline_emb)
    
    # Calculate per-cell distances to healthy reference
    baseline_dists = np.array([euclidean(cell, reference) for cell in baseline_emb])
    perturbed_dists = np.array([euclidean(cell, reference) for cell in perturbed_emb])
    
    # Identify cells that moved closer
    improved = perturbed_dists < baseline_dists
    pct_improved = (improved.sum() / n_cells) * 100
    
    return pct_improved, improved

# Analyze coverage for all ALS perturbations
coverage_results = []
als_perts = metadata[metadata['condition'] == 'ALS']

print(f"\nAnalyzing coverage for {len(als_perts)} ALS perturbations...")

for i, (_, row) in enumerate(als_perts.iterrows(), 1):
    pert_id = row['perturbation_id']
    
    if pert_id not in perturbation_embeddings:
        continue
    
    pct_improved, _ = compute_coverage(
        baseline_embeddings,
        perturbation_embeddings[pert_id],
        pn_reference
    )
    
    coverage_results.append({
        'perturbation_id': pert_id,
        'gene': row['gene'],
        'type': row['type'],
        'factor': row['factor'],
        'pct_cells_improved': pct_improved
    })
    
    if i % 5 == 0 or i == len(als_perts):
        print(f"  [{i}/{len(als_perts)}] Processed...")

coverage_df = pd.DataFrame(coverage_results)

print(f"\nCoverage statistics:")
print(f"  Mean: {coverage_df['pct_cells_improved'].mean():.1f}%")
print(f"  Median: {coverage_df['pct_cells_improved'].median():.1f}%")
print(f"  Range: {coverage_df['pct_cells_improved'].min():.1f}% - {coverage_df['pct_cells_improved'].max():.1f}%")

# Save results
coverage_df.to_csv(TABLES_DIR / "coverage_analysis.csv", index=False)
print(f"\nSaved: coverage_analysis.csv")


Computing cell-level coverage...
PN reference: averaged 12 PN perturbation centroids

Analyzing coverage for 12 ALS perturbations...
  [5/12] Processed...
  [10/12] Processed...
  [12/12] Processed...

Coverage statistics:
  Mean: 21.3%
  Median: 22.0%
  Range: 10.0% - 28.0%

Saved: coverage_analysis.csv


## 4. Integrate All Metrics

Combine rescue, coverage, effect size, and preservation into single dataset

In [6]:
print("\nMerging all metrics...")

# Start with rescue data (ALS perturbations only)
als_full = rescue_analysis.merge(
    coverage_df[['perturbation_id', 'pct_cells_improved']], 
    on='perturbation_id', 
    how='left'
)

# Add embedding metrics from Task 3
als_full = als_full.merge(
    full_analysis[['perturbation_id', 'centroid_shift', 'cosine_shift', 
                   'mean_cell_shift', 'neighborhood_preservation']], 
    on='perturbation_id', 
    how='left'
)

print(f"Integrated dataset: {len(als_full)} ALS perturbations")
print(f"\nMetrics available:")
for col in ['rescue_score', 'pct_cells_improved','centroid_shift', 'neighborhood_preservation']:
    print(f"  - {col}")

# Save integrated data
als_full.to_csv(TABLES_DIR / "als_perturbations_all_metrics.csv", index=False)


Merging all metrics...
Integrated dataset: 12 ALS perturbations

Metrics available:
  - rescue_score
  - pct_cells_improved
  - centroid_shift
  - neighborhood_preservation


## 5. Optimize Perturbation Factors

**Goal**: For each gene, identify the optimal perturbation strength (factor)

**Method**: 
- Compute composite score balancing all criteria
- Compare all tested factors for each gene
- Select factor with highest composite score

**Composite Score Formula**:
- 40% Disease rescue (primary objective)
- 30% Coverage (reliability across cells)
- 15% Effect size (appropriate strength)
- 15% Structure preservation (maintains identity)

In [7]:
print("\nOptimizing perturbation factors for each gene...")

def compute_composite_score(row):
    """
    Calculate composite therapeutic score from multiple criteria.
    
    Scoring logic:
    - Rescue: More negative is better (moving toward healthy)
    - Coverage: Higher percentage is better
    - Effect: Sweet spot of 3-8 (sufficient but not excessive)
    - Preservation: Higher is better (maintains cell identity)
    
    Returns: Composite score (0-1, higher is better)
    """
    # 1. Disease rescue (normalize to 0-1, cap at 5)
    rescue_norm = max(0, -row['rescue_score']) / 5.0
    rescue_norm = min(rescue_norm, 1.0)
    
    # 2. Coverage (convert percentage to 0-1)
    coverage_norm = row['pct_cells_improved'] / 100.0
    
    # 3. Effect size (optimal range: 3-8)
    effect = row['centroid_shift']
    if 3.0 <= effect <= 8.0:
        effect_norm = 1.0  # Ideal range
    elif effect < 3.0:
        effect_norm = effect / 3.0  # Too weak
    else:
        effect_norm = max(0, 1.0 - (effect - 8.0) / 10.0)  # Too strong
    
    # 4. Structure preservation (already 0-1)
    preservation_norm = row['neighborhood_preservation']
    
    # Weighted combination
    composite = (
        0.40 * rescue_norm +       # Disease rescue (primary)
        0.30 * coverage_norm +     # Cell coverage (reliability)
        0.15 * effect_norm +       # Effect size (potency)
        0.15 * preservation_norm   # Preservation (identity)
    )
    
    return composite

# Compute scores for all perturbations
als_full['composite_score'] = als_full.apply(compute_composite_score, axis=1)

# Find optimal perturbation for each gene
optimal_perturbations = []
genes = als_full['gene'].unique()

print(f"\nFinding optimal factor for {len(genes)} genes...")

for gene in genes:
    gene_data = als_full[als_full['gene'] == gene].copy()
    
    if len(gene_data) == 0:
        continue
    
    # Select best perturbation by composite score
    best_idx = gene_data['composite_score'].idxmax()
    best_pert = gene_data.loc[best_idx]
    
    # Record all tested factors for reference
    all_perts = gene_data.sort_values('composite_score', ascending=False)
    
    optimal_perturbations.append({
        'gene': gene,
        'optimal_perturbation_id': best_pert['perturbation_id'],
        'optimal_type': best_pert['type'],
        'optimal_factor': best_pert['factor'],
        'composite_score': best_pert['composite_score'],
        'rescue_score': best_pert['rescue_score'],
        'pct_cells_improved': best_pert['pct_cells_improved'],
        'centroid_shift': best_pert['centroid_shift'],
        'neighborhood_preservation': best_pert['neighborhood_preservation'],
        'moves_toward_healthy': best_pert['moves_toward_healthy'],
        'n_perturbations_tested': len(gene_data),
        'all_factors_tested': ','.join([f"{p['type']}_{p['factor']}" 
                                        for _, p in all_perts.iterrows()])
    })

optimal_df = pd.DataFrame(optimal_perturbations)
optimal_df = optimal_df.sort_values('composite_score', ascending=False)

print(f"\nOptimization complete: {len(optimal_df)} genes")
print(f"\nOptimal factor distribution:")
for factor, count in optimal_df['optimal_factor'].value_counts().sort_index().items():
    print(f"  {factor}x: {count} genes")

# Save results
optimal_df.to_csv(TABLES_DIR / "optimal_therapeutic_targets.csv", index=False)


Optimizing perturbation factors for each gene...

Finding optimal factor for 3 genes...

Optimization complete: 3 genes

Optimal factor distribution:
  0.5x: 3 genes


## 6. Top Therapeutic Candidates

Display top 10 genes ranked by composite score

In [8]:
print("\n" + "="*70)
print("TOP 10 THERAPEUTIC CANDIDATES")
print("="*70 + "\n")

for i, (_, row) in enumerate(optimal_df.head(10).iterrows(), 1):
    print(f"{i:2d}. {row['gene']:10s} | {row['optimal_type']:10s} @ {row['optimal_factor']}x")
    print(f"    Composite: {row['composite_score']:.3f}")
    print(f"    ├─ Rescue: {row['rescue_score']:7.4f} "
          f"({'✓ toward PN' if row['moves_toward_healthy'] else '✗ away from PN'})")
    print(f"    ├─ Coverage: {row['pct_cells_improved']:5.1f}% cells improved")
    print(f"    ├─ Effect: {row['centroid_shift']:6.3f} (centroid shift)")
    print(f"    └─ Preservation: {row['neighborhood_preservation']:.3f}")
    print(f"    Tested: {row['n_perturbations_tested']} factors "
          f"[{row['all_factors_tested']}]")
    print()


TOP 10 THERAPEUTIC CANDIDATES

 1. DMD        | knock_down @ 0.5x
    Composite: 0.213
    ├─ Rescue:  0.0004 (✗ away from PN)
    ├─ Coverage:  28.0% cells improved
    ├─ Effect:  0.018 (centroid shift)
    └─ Preservation: 0.854
    Tested: 4 factors [knock_down_0.5,knock_up_2.0,knock_up_3.0,knock_down_0.2]

 2. MAP1B      | knock_down @ 0.5x
    Composite: 0.207
    ├─ Rescue: -0.0002 (✓ toward PN)
    ├─ Coverage:  26.0% cells improved
    ├─ Effect:  0.027 (centroid shift)
    └─ Preservation: 0.854
    Tested: 4 factors [knock_down_0.5,knock_down_0.2,knock_up_3.0,knock_up_2.0]

 3. KHDRBS2    | knock_down @ 0.5x
    Composite: 0.190
    ├─ Rescue: -0.0029 (✓ toward PN)
    ├─ Coverage:  18.0% cells improved
    ├─ Effect:  0.012 (centroid shift)
    └─ Preservation: 0.900
    Tested: 4 factors [knock_down_0.5,knock_down_0.2,knock_up_2.0,knock_up_3.0]



## 7. Generate Biological Rationale

Create interpretable descriptions of why each gene is a therapeutic candidate

In [9]:
def generate_therapeutic_rationale(row):
    """
    Generate biological rationale for therapeutic candidacy.
    
    Considers:
    - Perturbation direction and strength
    - Disease rescue magnitude
    - Coverage across cells
    - Preservation of cell identity
    
    Returns: Human-readable rationale string
    """
    gene = row['gene']
    pert_type = row['optimal_type']
    factor = row['optimal_factor']
    rescue = row['rescue_score']
    coverage = row['pct_cells_improved']
    preservation = row['neighborhood_preservation']
    
    # Describe perturbation
    if pert_type == 'knock_down':
        action = f"Knockdown to {factor*100:.0f}% expression"
    else:
        action = f"Upregulation by {factor}x"
    
    # Build rationale components
    components = []
    
    # Rescue effect
    if rescue < -0.5:
        components.append(f"strong disease rescue (Δ={abs(rescue):.2f} toward healthy)")
    elif rescue < -0.2:
        components.append(f"moderate disease rescue (Δ={abs(rescue):.2f})")
    else:
        components.append(f"mild rescue effect (Δ={abs(rescue):.2f})")
    
    # Coverage
    if coverage > 80:
        components.append(f"benefits most cells ({coverage:.0f}%)")
    elif coverage > 60:
        components.append(f"moderate coverage ({coverage:.0f}%)")
    else:
        components.append(f"limited coverage ({coverage:.0f}%)")
    
    # Preservation
    if preservation > 0.85:
        components.append("excellent identity preservation")
    elif preservation > 0.70:
        components.append("good structural integrity")
    else:
        components.append("some structural disruption")
    
    return f"{action}: {', '.join(components)}."

# Generate rationales
optimal_df['rationale'] = optimal_df.apply(generate_therapeutic_rationale, axis=1)

print("\n" + "="*70)
print("THERAPEUTIC RATIONALE (TOP 5)")
print("="*70 + "\n")

for i, (_, row) in enumerate(optimal_df.head(5).iterrows(), 1):
    print(f"{i}. {row['gene'].upper()}")
    print(f"   {row['rationale']}\n")

# Save with rationales
optimal_df.to_csv(TABLES_DIR / "therapeutic_targets_with_rationale.csv", index=False)


THERAPEUTIC RATIONALE (TOP 5)

1. DMD
   Knockdown to 50% expression: mild rescue effect (Δ=0.00), limited coverage (28%), excellent identity preservation.

2. MAP1B
   Knockdown to 50% expression: mild rescue effect (Δ=0.00), limited coverage (26%), excellent identity preservation.

3. KHDRBS2
   Knockdown to 50% expression: mild rescue effect (Δ=0.00), limited coverage (18%), excellent identity preservation.



## 8. Visualizations

In [10]:
print("\n" + "="*70)
print("GENERATING VISUALIZATIONS")
print("="*70)


GENERATING VISUALIZATIONS


In [11]:
# Visualization 1: Therapeutic Ranking Overview (4-panel)
print("\n1. Therapeutic target ranking overview...")

fig, axes = plt.subplots(2, 2, figsize=(18, 14))

# Panel 1: Composite scores (top 15)
top_15 = optimal_df.head(15)
y_pos = np.arange(len(top_15))
colors = ['green' if x else 'orange' for x in top_15['moves_toward_healthy']]

axes[0, 0].barh(y_pos, top_15['composite_score'], 
               color=colors, alpha=0.7, edgecolor='black', linewidth=1.2)
axes[0, 0].set_yticks(y_pos)
axes[0, 0].set_yticklabels([f"{r['gene']} ({r['optimal_type'][:2].upper()}@{r['optimal_factor']})" 
                            for _, r in top_15.iterrows()], fontsize=10)
axes[0, 0].set_xlabel('Composite Therapeutic Score', fontsize=12, fontweight='bold')
axes[0, 0].set_title('Top 15 Therapeutic Candidates', fontsize=14, fontweight='bold', pad=10)
axes[0, 0].invert_yaxis()
axes[0, 0].grid(axis='x', alpha=0.3)

# Panel 2: Score components for top gene
top_gene = optimal_df.iloc[0]
components = ['Rescue\nPotential', 'Coverage\n(% cells)', 'Effect\nSize', 'Preservation']
component_values = [
    min(max(0, -top_gene['rescue_score']) / 5.0, 1.0),
    top_gene['pct_cells_improved'] / 100.0,
    min(top_gene['centroid_shift'] / 8.0, 1.0) if top_gene['centroid_shift'] <= 8 
        else max(0, 1.0 - (top_gene['centroid_shift'] - 8.0) / 10.0),
    top_gene['neighborhood_preservation']
]

bars = axes[0, 1].bar(components, component_values, 
                     color=['#e74c3c', '#3498db', '#f39c12', '#9b59b6'],
                     alpha=0.7, edgecolor='black', linewidth=1.2)
axes[0, 1].set_ylim(0, 1.0)
axes[0, 1].set_ylabel('Normalized Score', fontsize=12, fontweight='bold')
axes[0, 1].set_title(f'Score Components: {top_gene["gene"]} '
                    f'({top_gene["optimal_type"][:2].upper()}@{top_gene["optimal_factor"]})', 
                    fontsize=14, fontweight='bold', pad=10)
axes[0, 1].grid(axis='y', alpha=0.3)

for bar, val in zip(bars, component_values):
    axes[0, 1].text(bar.get_x() + bar.get_width()/2., val + 0.03,
                   f'{val:.2f}', ha='center', va='bottom', 
                   fontsize=11, fontweight='bold')

# Panel 3: Rescue vs Coverage
scatter = axes[1, 0].scatter(-optimal_df['rescue_score'],
                            optimal_df['pct_cells_improved'],
                            c=optimal_df['composite_score'],
                            cmap='RdYlGn', s=180, alpha=0.7,
                            edgecolors='black', linewidth=1.2)
axes[1, 0].set_xlabel('Rescue Potential\n(distance moved toward healthy)', 
                     fontsize=12, fontweight='bold')
axes[1, 0].set_ylabel('Coverage (% Cells Improved)', fontsize=12, fontweight='bold')
axes[1, 0].set_title('Efficacy vs Reliability', fontsize=14, fontweight='bold', pad=10)
axes[1, 0].grid(alpha=0.3)
cbar1 = plt.colorbar(scatter, ax=axes[1, 0])
cbar1.set_label('Composite Score', fontweight='bold', fontsize=11)

# Annotate top 3
for _, row in optimal_df.head(3).iterrows():
    axes[1, 0].annotate(row['gene'], 
                       (-row['rescue_score'], row['pct_cells_improved']),
                       fontsize=10, fontweight='bold',
                       bbox=dict(boxstyle='round,pad=0.4', facecolor='yellow', alpha=0.6))

# Panel 4: Effect size vs Preservation
scatter2 = axes[1, 1].scatter(optimal_df['centroid_shift'],
                             optimal_df['neighborhood_preservation'],
                             c=optimal_df['composite_score'],
                             cmap='RdYlGn', s=180, alpha=0.7,
                             edgecolors='black', linewidth=1.2)
axes[1, 1].axvline(3.0, color='red', linestyle='--', alpha=0.4, 
                  linewidth=2, label='Min effective')
axes[1, 1].axvline(8.0, color='red', linestyle='--', alpha=0.4, 
                  linewidth=2, label='Max optimal')
axes[1, 1].set_xlabel('Effect Size (Centroid Shift)', fontsize=12, fontweight='bold')
axes[1, 1].set_ylabel('Structure Preservation', fontsize=12, fontweight='bold')
axes[1, 1].set_title('Potency vs Identity Retention', fontsize=14, fontweight='bold', pad=10)
axes[1, 1].legend(loc='lower left', fontsize=10, framealpha=0.9)
axes[1, 1].grid(alpha=0.3)
cbar2 = plt.colorbar(scatter2, ax=axes[1, 1])
cbar2.set_label('Composite Score', fontweight='bold', fontsize=11)

# Annotate top 3
for _, row in optimal_df.head(3).iterrows():
    axes[1, 1].annotate(row['gene'],
                       (row['centroid_shift'], row['neighborhood_preservation']),
                       fontsize=10, fontweight='bold',
                       bbox=dict(boxstyle='round,pad=0.4', facecolor='yellow', alpha=0.6))

plt.suptitle('Task 4: Therapeutic Target Prioritization', 
            fontsize=18, fontweight='bold', y=0.995)
plt.tight_layout()
plt.savefig(FIGURES_DIR / 'therapeutic_ranking.png', dpi=150, bbox_inches='tight')
plt.close()
print("   Saved: therapeutic_ranking.png")


1. Therapeutic target ranking overview...
   Saved: therapeutic_ranking.png


In [12]:
# Visualization 2: Factor Optimization Analysis
print("\n2. Factor optimization distribution...")

fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# Panel 1: Distribution of optimal factors
factor_dist = optimal_df['optimal_factor'].value_counts().sort_index()
bars = axes[0].bar(range(len(factor_dist)), factor_dist.values, 
                  alpha=0.7, edgecolor='black', color='steelblue', linewidth=1.2)
axes[0].set_xticks(range(len(factor_dist)))
axes[0].set_xticklabels([f"{f}x" for f in factor_dist.index], fontsize=12)
axes[0].set_ylabel('Number of Genes', fontsize=13, fontweight='bold')
axes[0].set_xlabel('Optimal Perturbation Factor', fontsize=13, fontweight='bold')
axes[0].set_title('Distribution of Optimal Factors', fontsize=14, fontweight='bold', pad=10)
axes[0].grid(axis='y', alpha=0.3)

for bar, val in zip(bars, factor_dist.values):
    axes[0].text(bar.get_x() + bar.get_width()/2., val + 0.3,
                str(int(val)), ha='center', va='bottom', 
                fontsize=12, fontweight='bold')

# Panel 2: Perturbation type distribution
type_dist = optimal_df['optimal_type'].value_counts()
type_labels = ['Knockdown' if t == 'knock_down' else 'Knockup' for t in type_dist.index]
colors_type = ['#e74c3c' if t == 'knock_down' else '#2ecc71' for t in type_dist.index]
bars2 = axes[1].bar(range(len(type_dist)), type_dist.values, 
                   alpha=0.7, edgecolor='black', color=colors_type, linewidth=1.2)
axes[1].set_xticks(range(len(type_dist)))
axes[1].set_xticklabels(type_labels, fontsize=12)
axes[1].set_ylabel('Number of Genes', fontsize=13, fontweight='bold')
axes[1].set_xlabel('Optimal Perturbation Type', fontsize=13, fontweight='bold')
axes[1].set_title('Knockdown vs Knockup Preference', fontsize=14, fontweight='bold', pad=10)
axes[1].grid(axis='y', alpha=0.3)

for bar, val in zip(bars2, type_dist.values):
    axes[1].text(bar.get_x() + bar.get_width()/2., val + 0.3,
                str(int(val)), ha='center', va='bottom', 
                fontsize=12, fontweight='bold')

plt.tight_layout()
plt.savefig(FIGURES_DIR / 'factor_optimization.png', dpi=150, bbox_inches='tight')
plt.close()
print("   Saved: factor_optimization.png")


2. Factor optimization distribution...
   Saved: factor_optimization.png


In [13]:
# Visualization 3: Decision Matrix Heatmap
print("\n3. Decision matrix (top 20 genes)...")

fig, ax = plt.subplots(figsize=(12, 11))

# Select top 20 for clarity
top_20 = optimal_df.head(20)

# Normalize metrics to 0-1 for comparison
heatmap_data = pd.DataFrame({
    'Rescue': -top_20['rescue_score'] / top_20['rescue_score'].abs().max(),
    'Coverage': top_20['pct_cells_improved'] / 100.0,
    'Effect': top_20['centroid_shift'] / top_20['centroid_shift'].max(),
    'Preservation': top_20['neighborhood_preservation'],
    'COMPOSITE': top_20['composite_score']
}, index=[f"{r['gene']} ({r['optimal_type'][:2].upper()}@{r['optimal_factor']})" 
          for _, r in top_20.iterrows()])

# Plot heatmap
sns.heatmap(heatmap_data, annot=True, fmt='.2f', cmap='RdYlGn',
           cbar_kws={'label': 'Normalized Score'},
           linewidths=0.5, linecolor='gray', ax=ax,
           vmin=0, vmax=1)

ax.set_title('Therapeutic Decision Matrix (Top 20 Candidates)', 
            fontsize=15, fontweight='bold', pad=15)
ax.set_xlabel('Evaluation Criteria', fontsize=13, fontweight='bold')
ax.set_ylabel('Gene (Type@Factor)', fontsize=13, fontweight='bold')

plt.tight_layout()
plt.savefig(FIGURES_DIR / 'decision_matrix.png', dpi=150, bbox_inches='tight')
plt.close()
print("   Saved: decision_matrix.png")


3. Decision matrix (top 20 genes)...
   Saved: decision_matrix.png


In [14]:
# Visualization 4: Coverage Analysis
print("\n4. Coverage distribution analysis...")

fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# Panel 1: Coverage histogram
axes[0].hist(optimal_df['pct_cells_improved'], bins=20, alpha=0.7, 
            edgecolor='black', color='steelblue', linewidth=1.2)
median_coverage = optimal_df['pct_cells_improved'].median()
axes[0].axvline(median_coverage, color='red', linestyle='--', linewidth=2.5,
               label=f'Median: {median_coverage:.1f}%')
axes[0].set_xlabel('Coverage (% Cells Improved)', fontsize=13, fontweight='bold')
axes[0].set_ylabel('Number of Genes', fontsize=13, fontweight='bold')
axes[0].set_title('Distribution of Cell Coverage', fontsize=14, fontweight='bold', pad=10)
axes[0].legend(fontsize=12, framealpha=0.9)
axes[0].grid(axis='y', alpha=0.3)

# Panel 2: Coverage by perturbation factor
coverage_by_factor = optimal_df.groupby('optimal_factor')['pct_cells_improved'].mean()
bars = axes[1].bar(range(len(coverage_by_factor)), coverage_by_factor.values,
                  alpha=0.7, edgecolor='black', color='lightcoral', linewidth=1.2)
axes[1].set_xticks(range(len(coverage_by_factor)))
axes[1].set_xticklabels([f"{f}x" for f in coverage_by_factor.index], fontsize=12)
axes[1].set_ylabel('Mean Coverage (%)', fontsize=13, fontweight='bold')
axes[1].set_xlabel('Perturbation Factor', fontsize=13, fontweight='bold')
axes[1].set_title('Coverage by Perturbation Strength', fontsize=14, fontweight='bold', pad=10)
axes[1].grid(axis='y', alpha=0.3)

for bar, val in zip(bars, coverage_by_factor.values):
    axes[1].text(bar.get_x() + bar.get_width()/2., val + 1,
                f'{val:.1f}%', ha='center', va='bottom', 
                fontsize=11, fontweight='bold')

plt.tight_layout()
plt.savefig(FIGURES_DIR / 'coverage_analysis.png', dpi=150, bbox_inches='tight')
plt.close()
print("   Saved: coverage_analysis.png")

print("\nAll visualizations complete")


4. Coverage distribution analysis...
   Saved: coverage_analysis.png

All visualizations complete


## 9. Summary Report

In [15]:
print("\n" + "="*70)
print("TASK 4 SUMMARY")
print("="*70)

# Summary statistics
print(f"\nAnalysis Overview:")
print(f"  Genes evaluated: {len(optimal_df)}")
print(f"  Disease-rescuing genes: {optimal_df['moves_toward_healthy'].sum()}")
print(f"  Mean coverage: {optimal_df['pct_cells_improved'].mean():.1f}%")
print(f"  Mean rescue score: {optimal_df['rescue_score'].mean():.3f}")

print(f"\n" + "="*70)
print("TOP 5 THERAPEUTIC CANDIDATES")
print("="*70)

for i, (_, row) in enumerate(optimal_df.head(5).iterrows(), 1):
    print(f"\n{i}. {row['gene'].upper()} - "
          f"{row['optimal_type'].replace('_', ' ').title()} @ {row['optimal_factor']}x")
    print(f"   Composite Score: {row['composite_score']:.3f}")
    print(f"   Metrics:")
    print(f"     • Rescue: {row['rescue_score']:.4f} "
          f"({'✓ rescues' if row['moves_toward_healthy'] else '✗ no rescue'})")
    print(f"     • Coverage: {row['pct_cells_improved']:.1f}% cells improved")
    print(f"     • Effect: {row['centroid_shift']:.3f}")
    print(f"     • Preservation: {row['neighborhood_preservation']:.3f}")
    print(f"   Rationale: {row['rationale']}")

print(f"\nOutputs:")
print(f"  {TABLES_DIR}/")
print(f"    - coverage_analysis.csv")
print(f"    - als_perturbations_all_metrics.csv")
print(f"    - optimal_therapeutic_targets.csv")
print(f"    - therapeutic_targets_with_rationale.csv")
print(f"  {FIGURES_DIR}/")
print(f"    - therapeutic_ranking.png")
print(f"    - factor_optimization.png")
print(f"    - decision_matrix.png")
print(f"    - coverage_analysis.png")

print(f"\nNext Steps:")
print(f"  1. Validate top candidates in experimental models")
print(f"  2. Investigate molecular mechanisms of rescue")
print(f"  3. Assess potential combination therapies")

print("\nTask 4 complete - Therapeutic targets prioritized with optimal dosing")


TASK 4 SUMMARY

Analysis Overview:
  Genes evaluated: 3
  Disease-rescuing genes: 2
  Mean coverage: 24.0%
  Mean rescue score: -0.001

TOP 5 THERAPEUTIC CANDIDATES

1. DMD - Knock Down @ 0.5x
   Composite Score: 0.213
   Metrics:
     • Rescue: 0.0004 (✗ no rescue)
     • Coverage: 28.0% cells improved
     • Effect: 0.018
     • Preservation: 0.854
   Rationale: Knockdown to 50% expression: mild rescue effect (Δ=0.00), limited coverage (28%), excellent identity preservation.

2. MAP1B - Knock Down @ 0.5x
   Composite Score: 0.207
   Metrics:
     • Rescue: -0.0002 (✓ rescues)
     • Coverage: 26.0% cells improved
     • Effect: 0.027
     • Preservation: 0.854
   Rationale: Knockdown to 50% expression: mild rescue effect (Δ=0.00), limited coverage (26%), excellent identity preservation.

3. KHDRBS2 - Knock Down @ 0.5x
   Composite Score: 0.190
   Metrics:
     • Rescue: -0.0029 (✓ rescues)
     • Coverage: 18.0% cells improved
     • Effect: 0.012
     • Preservation: 0.900
   Ratio