# Baseline ISP Analysis on Colab

This notebook runs the baseline In-Silico Perturbation analysis using Geneformer.

## Features:
- Loads Geneformer model
- Processes SciPlex2 perturbation data
- Generates embeddings
- Evaluates embedding quality (classification + geometry metrics)
- Optional Weights & Biases (wandb) experiment tracking


In [None]:
# Import required modules
# NOTE(review): sys is not referenced by any later cell in this notebook; os, numpy and
# pandas are used further down.
import os
import sys
import numpy as np
import pandas as pd
import torch
import warnings
# Silences *all* Python warnings for the rest of the kernel session, including any raised
# while importing baseline_isp below (the filter is installed before that import).
# NOTE(review): consider narrowing this to specific categories/modules so genuine
# problems (e.g. deprecations in the model code) are not hidden.
warnings.filterwarnings("ignore")

# Import our custom modules
# baseline_isp is a project-local module; it must be importable from the kernel's
# working directory / sys.path for this line to succeed.
from baseline_isp import GeneformerISPOptimizer

# Environment sanity check: confirms the imports worked and reports the torch version
# and whether a CUDA device is visible to this kernel.
print("Imports successful!")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")


## Configuration

Set your parameters here:


In [None]:
# Configuration: every tunable setting for this run is collected in one dict so the
# rest of the notebook reads from a single source of truth.
CONFIG = dict(
    model_name='gf-6L-10M-i2048',                            # smaller Geneformer variant, suited to Colab
    num_perturbations=100,                                   # reduce if Colab runs short on memory
    device='cuda' if torch.cuda.is_available() else 'cpu',  # prefer the GPU when one is visible
    use_wandb=False,                                         # set True to log this run to wandb
    wandb_project='ispo-colab',
    output_dir='results/baseline',
)

# Echo the settings, one indented "name: value" line per entry.
print("Configuration:")
print("\n".join(f"  {setting_name}: {setting_value}" for setting_name, setting_value in CONFIG.items()))


## Initialize Optimizer


In [None]:
# Construct the optimizer from CONFIG. The wandb run name is the model name with a
# "_colab" suffix.
optimizer_kwargs = {
    'model_name': CONFIG['model_name'],
    'device': CONFIG['device'],
    'use_wandb': CONFIG['use_wandb'],
    'wandb_project': CONFIG['wandb_project'],
    'wandb_run_name': f"{CONFIG['model_name']}_colab",
}
optimizer = GeneformerISPOptimizer(**optimizer_kwargs)

print("\n".join([
    f"Optimizer initialized with model: {CONFIG['model_name']}",
    f"   Device: {CONFIG['device']}",
    f"   Wandb tracking: {CONFIG['use_wandb']}",
]))


## Load Model


In [None]:
# Load the Geneformer weights into the optimizer. Weights are downloaded on the first
# call if they are not already present in the local cache.
optimizer.load_model()

print("Model loaded successfully!")


## Load Perturbation Data


In [None]:
# Fetch the SciPlex2 perturbation dataset (downloaded on first use), passing the
# configured number of perturbations through to the loader.
perturbation_data = optimizer.load_perturbation_data(num_perturbations=CONFIG['num_perturbations'])

print(f"Data loaded: {perturbation_data.shape}")
# Tally of cells per perturbation label, as a plain dict for printing.
perturbation_counts = perturbation_data.obs['perturbation'].value_counts().to_dict()
print(f"   Perturbation types: {perturbation_counts}")


## Run Baseline Inference

This will:
1. Process the data
2. Generate embeddings
3. Evaluate embedding quality
4. Track performance metrics


In [None]:
# Run baseline inference with evaluation: embeds the loaded cells, evaluates the
# embeddings, and records performance metrics under CONFIG['output_dir'].
results = optimizer.run_baseline_inference(
    perturbation_data,
    output_dir=CONFIG['output_dir']
)

# Throughput is computed from the number of embeddings actually produced.
# CONFIG['num_perturbations'] is only what was requested from the loader and may not
# match the number of rows that were embedded.
n_embedded = results['embeddings'].shape[0]
total_time = results['performance_metrics']['total_time_seconds']

print("Inference completed!")
print(f"   Embeddings shape: {results['embeddings'].shape}")
print(f"   Total time: {total_time:.2f}s")
if total_time > 0:
    print(f"   Throughput: {n_embedded / total_time:.2f} samples/s")
else:
    # Guard against a zero timer reading (e.g. a tiny dataset) instead of crashing.
    print("   Throughput: n/a (total time was zero)")


## View Results


In [None]:
# Display the performance, classification and geometry metrics returned by
# run_baseline_inference as aligned "name: value" tables.


def _format_metric(value):
    """Render one metric value: 4-decimal fixed point for real numbers, str() otherwise.

    bool is a subclass of int, so it is excluded explicitly to print True/False rather
    than 1.0000; numbers.Real also covers numpy scalar types such as np.int64.
    """
    import numbers

    if isinstance(value, numbers.Real) and not isinstance(value, bool):
        return f"{value:.4f}"
    return f"{value}"


def _print_metric_section(title, metrics, empty_message, *, skip_none=False, leading_newline=True):
    """Print a banner titled `title`, then one aligned line per entry of `metrics`.

    Args:
        title: Heading shown between the two "=" rules.
        metrics: Mapping of metric name -> value; a falsy value prints `empty_message`.
        empty_message: Text printed when there is nothing to show.
        skip_none: When True, entries whose value is None are omitted.
        leading_newline: Whether to put a blank line before the banner.
    """
    print(("\n" if leading_newline else "") + "=" * 60)
    print(title)
    print("=" * 60)
    if not metrics:
        print(empty_message)
        return
    for key, value in metrics.items():
        if skip_none and value is None:
            continue
        print(f"{key!s:30}: {_format_metric(value)}")


perf_metrics = results['performance_metrics']
eval_metrics = results['evaluation_metrics']
geom_metrics = results['geometry_metrics']

_print_metric_section("PERFORMANCE METRICS", perf_metrics,
                      "No performance metrics available", leading_newline=False)
_print_metric_section("EVALUATION METRICS (Classification)", eval_metrics,
                      "No evaluation metrics available")
_print_metric_section("GEOMETRY METRICS (Clustering & Separation)", geom_metrics,
                      "No geometry metrics available", skip_none=True)


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.decomposition import PCA
import umap

# Set style
sns.set_style("whitegrid")
plt.rcParams['figure.figsize'] = (12, 5)


def _label_colors(unique_labels):
    """Map each label to its own RGBA color.

    tab10 has only 10 entries, so sampling it for more labels silently reuses colors
    and merges clusters visually. Use tab10 for up to 10 labels, tab20 for up to 20,
    and sample a continuous colormap beyond that.
    """
    n_labels = len(unique_labels)
    if n_labels <= 10:
        colors = plt.cm.tab10(np.arange(n_labels))   # integer input indexes the palette directly
    elif n_labels <= 20:
        colors = plt.cm.tab20(np.arange(n_labels))
    else:
        colors = plt.cm.gist_rainbow(np.linspace(0, 1, n_labels))
    return dict(zip(unique_labels, colors))


def _scatter_by_label(ax, coords, labels, label_to_color, title, xlabel, ylabel):
    """Scatter 2-D `coords` on `ax` with one colored series per label, then style the panel."""
    for label, color in label_to_color.items():
        mask = labels == label
        ax.scatter(
            coords[mask, 0],
            coords[mask, 1],
            label=label,
            alpha=0.6,
            s=20,
            c=[color]
        )
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    ax.grid(True, alpha=0.3)


# Get embeddings and labels
embeddings = results['embeddings']
labels = perturbation_data.obs['perturbation'].values
unique_labels = np.unique(labels)
label_to_color = _label_colors(unique_labels)

# Create subplots
fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# 1. PCA visualization
pca = PCA(n_components=2, random_state=42)
embeddings_pca = pca.fit_transform(embeddings)
_scatter_by_label(
    axes[0], embeddings_pca, labels, label_to_color,
    'PCA Visualization of Embeddings',
    f'PC1 ({pca.explained_variance_ratio_[0]:.2%} variance)',
    f'PC2 ({pca.explained_variance_ratio_[1]:.2%} variance)',
)

# 2. UMAP visualization
reducer = umap.UMAP(n_components=2, random_state=42, n_neighbors=15, min_dist=0.1)
embeddings_umap = reducer.fit_transform(embeddings)
_scatter_by_label(
    axes[1], embeddings_umap, labels, label_to_color,
    'UMAP Visualization of Embeddings', 'UMAP 1', 'UMAP 2',
)

plt.tight_layout()
plt.show()

print("Visualizations created!")
print(f"   PCA explained variance: {pca.explained_variance_ratio_[:2].sum():.2%}")


In [None]:
# Two-panel summary figure: resource utilisation (left) and classification scores (right).
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
usage_ax, score_ax = axes

# Left panel: average vs. peak utilisation for each tracked resource.
# Each column label maps to its (average, peak) metric keys; missing keys plot as 0.
perf_metrics = results['performance_metrics']
resource_keys = {
    'CPU (%)': ('avg_cpu_percent', 'max_cpu_percent'),
    'Memory (%)': ('avg_memory_percent', 'max_memory_percent'),
}
# GPU utilisation is only plotted when the run recorded it.
if 'avg_gpu_percent' in perf_metrics:
    resource_keys['GPU (%)'] = ('avg_gpu_percent', 'max_gpu_percent')
resource_data = {
    column: [perf_metrics.get(avg_key, 0), perf_metrics.get(peak_key, 0)]
    for column, (avg_key, peak_key) in resource_keys.items()
}

resource_df = pd.DataFrame(resource_data, index=['Average', 'Peak'])
resource_df.plot(kind='bar', ax=usage_ax, color=['#1f77b4', '#ff7f0e', '#2ca02c'])
usage_ax.set_title('Resource Usage', fontsize=14, fontweight='bold')
usage_ax.set_ylabel('Usage (%)')
usage_ax.legend()
usage_ax.grid(True, alpha=0.3, axis='y')

# Right panel: zero-shot accuracy and F1, or a placeholder when no evaluation ran.
if not results['evaluation_metrics']:
    score_ax.text(0.5, 0.5, 'No evaluation metrics',
                  ha='center', va='center', transform=score_ax.transAxes)
    score_ax.set_title('Classification Performance', fontsize=14, fontweight='bold')
else:
    eval_metrics = results['evaluation_metrics']
    eval_df = pd.DataFrame({
        'Accuracy': [eval_metrics.get('zero_shot_accuracy', 0)],
        'F1 Score': [eval_metrics.get('zero_shot_f1_score', 0)],
    })
    eval_df.plot(kind='bar', ax=score_ax, color=['#9467bd', '#8c564b'], width=0.6)
    score_ax.set_title('Classification Performance', fontsize=14, fontweight='bold')
    score_ax.set_ylabel('Score')
    score_ax.set_ylim(0, 1)
    score_ax.set_xticklabels(['Baseline'], rotation=0)
    score_ax.legend()
    score_ax.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()


## Save Results Summary


In [None]:
# Save a CSV summary of this baseline run, then read it back for display.
# The path is defined once so the save, the message and the read cannot drift apart.
summary_path = "results/baseline_summary.csv"
# Ensure the parent directory exists rather than relying on an earlier cell to create it.
os.makedirs(os.path.dirname(summary_path), exist_ok=True)
optimizer.save_results_summary([results], summary_path)

print(f"Results summary saved to {summary_path}")

# Display summary
summary_df = pd.read_csv(summary_path)
print("\nSummary DataFrame:")
display(summary_df)


## Finish Wandb Run (if enabled)


In [None]:
# Finish the wandb run if tracking was enabled and a run was actually created.
# getattr guards against an optimizer that never set a wandb_run attribute.
wandb_run = getattr(optimizer, 'wandb_run', None)
if CONFIG['use_wandb'] and wandb_run is not None:
    wandb_run.finish()
    print("✅ Wandb run completed. Check your dashboard!")
elif CONFIG['use_wandb']:
    # Tracking was requested but no run exists; say so instead of claiming it was disabled.
    print("⚠️ Wandb tracking was enabled but no run was created")
else:
    print("ℹ️ Wandb tracking was not enabled")


## Download Results (Optional)

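Run the cell below to zip the contents of `CONFIG['output_dir']` into `results_baseline.zip` and download it through the browser. The download step requires a Google Colab runtime.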


In [None]:
# Download results files: zip everything under CONFIG['output_dir'] and, when running
# on Google Colab, hand the archive to the browser.
import zipfile
import os

# google.colab only exists inside a Colab runtime. Guard the import so a full
# "Restart & Run All" elsewhere still completes instead of raising ModuleNotFoundError.
try:
    from google.colab import files
except ImportError:
    files = None

# Create a zip file of results
results_dir = CONFIG['output_dir']
if os.path.exists(results_dir):
    zip_path = 'results_baseline.zip'
    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
        for root, _dirs, files_list in os.walk(results_dir):
            for file_name in files_list:
                file_path = os.path.join(root, file_name)
                # Store entries relative to results_dir so the archive has no path prefix.
                zipf.write(file_path, os.path.relpath(file_path, results_dir))

    if files is not None:
        # Download the zip file
        files.download(zip_path)
        print("✅ Results downloaded!")
    else:
        print(f"ℹ️ Not running on Colab; results archived to {zip_path}")
else:
    print("⚠️ Results directory not found")
