# Sperm Quantification Pipeline - Complete Demo

This notebook demonstrates the end-to-end sperm quantification pipeline:
1. Generate synthetic data using Active Brownian Particle physics
2. Detect and track sperm through video frames
3. Compute WHO-standardized motility metrics
4. Visualize trajectories and analyze results

In [None]:
import pickle
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Make the pipeline packages in ../src importable (path is relative to the notebook's directory).
sys.path.append('../src')

# Import pipeline modules
from simulation import (
    ABPParameters,
    MultiParticleABP,
    generate_synthetic_dataset
)
from detection import BlobDetector
from tracking import SpermTracker
from metrics import analyze_single_trajectory
from metrics.trajectory import compute_MSD, fit_MSD_diffusion
from visualization import (
    plot_correlation_matrix,
    plot_trajectories,
    plot_velocity_distributions,
    plot_MSD_curve
)

print("✓ Imports successful")

## 1. Generate Synthetic Sperm Trajectories

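We'll use an Active Brownian Particle (ABP) model to simulate realistic sperm swimming. Each cell self-propels at constant speed $v_0$ along its heading $\theta$. Both its position and its heading are perturbed by noise:

$$\dot{\mathbf{r}} = v_0\,\hat{\mathbf{e}}(\theta) + \sqrt{2D_t}\,\boldsymbol{\xi}(t), \qquad \dot{\theta} = \sqrt{2D_r}\,\eta(t),$$

Here $\hat{\mathbf{e}}(\theta) = (\cos\theta, \sin\theta)$, and $\boldsymbol{\xi}$ and $\eta$ are unit Gaussian white noise. $D_t$ and $D_r$ are the translational and rotational diffusion coefficients (`Dt` and `Dr` below).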

In [None]:
# Simulation configuration.
# The time step is derived from the frame rate so that `dt` and the `fps` passed to
# the dataset generator always agree (a hard-coded dt=0.033 is ~1% off from 1/30 s,
# which biases every downstream velocity computed with 1/fps).
FPS = 30
SEED = 42

# Seed for a reproducible synthetic dataset.
# NOTE(review): only effective if `generate_synthetic_dataset` draws from NumPy's
# global RNG — TODO confirm, and pass a seed explicitly if the API accepts one.
np.random.seed(SEED)

# Define simulation parameters
params = ABPParameters(
    v0=50.0,       # Self-propulsion speed (μm/s)
    Dr=0.5,        # Rotational diffusion (rad²/s)
    Dt=1.0,        # Translational diffusion (μm²/s)
    dt=1.0 / FPS,  # Time step (s), matched to the video frame interval
    width=500.0,   # Domain width (μm)
    height=500.0,  # Domain height (μm)
    boundary='reflective'
)

# Generate synthetic video
n_particles = 20
duration = 5.0  # seconds

print(f"Simulating {n_particles} sperm for {duration} seconds...")

video, gt_trajectories, metadata = generate_synthetic_dataset(
    n_particles=n_particles,
    duration=duration,
    params=params,
    fps=FPS
)

print(f"✓ Generated {metadata['n_frames']} frames")
print(f"  Video shape: {video.shape}")
print(f"  Pixel size: {metadata['pixel_size_um']} μm/pixel")

### Visualize Sample Frame

In [None]:
# Show the first, middle and last frames of the synthetic video side by side.
n_frames = len(video)
sample_frames = (0, n_frames // 2, n_frames - 1)

fig, axes = plt.subplots(1, len(sample_frames), figsize=(15, 5))

for ax, frame_idx in zip(axes, sample_frames):
    ax.imshow(video[frame_idx], cmap='gray')
    ax.set_title(f'Frame {frame_idx}')
    ax.axis('off')

plt.tight_layout()
plt.show()

### Visualize Ground Truth Trajectories

In [None]:
# Keep only the positional columns of each ground-truth trajectory.
# NOTE(review): assumes columns 0 and 1 are x and y — confirm against the simulation output.
gt_traj_xy = []
for gt_traj in gt_trajectories:
    gt_traj_xy.append(gt_traj[:, :2])

plot_trajectories(
    gt_traj_xy,
    title='Ground Truth Trajectories (Synthetic Data)',
    pixel_size_um=metadata['pixel_size_um'],
)

## 2. Detection and Tracking

Now we'll process the synthetic video to detect and track sperm.

In [None]:
# Initialize detector and tracker.
# NOTE(review): detector sigmas/areas and the tracker's max_distance are presumably
# in pixels — confirm against the detection/tracking modules.
detector = BlobDetector(
    method='dog',
    min_sigma=1.5,
    max_sigma=3.0,
    threshold=0.1,
    min_area=10,
    max_area=200
)

tracker = SpermTracker(
    max_distance=30,
    max_gap=5,
    min_track_length=10,
    use_kalman=True,
    dt=1.0/metadata['fps']
)

print("Processing video...")

# Report progress about once per second of video. Tying this to the dataset's
# frame rate (instead of a hard-coded 30) keeps the cadence right at other fps.
progress_every = max(1, int(round(metadata['fps'])))

for frame_idx, frame in enumerate(video):
    # Detect sperm in current frame
    detections = detector.detect(frame)

    # Update tracker
    active_tracks = tracker.update(detections)

    if (frame_idx + 1) % progress_every == 0:
        print(f"  Frame {frame_idx + 1}/{len(video)}: {len(active_tracks)} active tracks")

# Get all completed tracks
all_tracks = tracker.get_all_tracks()
print(f"\n✓ Tracking complete: {len(all_tracks)} tracks")

### Compare Tracked vs Ground Truth

In [None]:
tracked_trajectories = [track.get_trajectory() for track in all_tracks]

# Side-by-side comparison of ground-truth and tracked paths.
# NOTE(review): both panels are labelled in pixels, which assumes the ground-truth
# xy columns use the same pixel coordinates as the tracker output — confirm.
fig, axes = plt.subplots(1, 2, figsize=(16, 7))
panels = [
    (axes[0], gt_traj_xy, 'Ground Truth Trajectories'),
    (axes[1], tracked_trajectories, f'Tracked Trajectories (n={len(tracked_trajectories)})'),
]

for ax, trajectories, title in panels:
    for traj in trajectories:
        ax.plot(traj[:, 0], traj[:, 1], alpha=0.7, linewidth=1.5)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.set_xlabel('X (pixels)')
    ax.set_ylabel('Y (pixels)')
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Track-count ratio. This is NOT tracking accuracy:
# - Fragmented tracks inflate the ratio, so it can exceed 100%.
# - Identity swaps do not show up in it at all.
# A real accuracy score needs tracks matched to ground truth (e.g. MOTA / IDF1).
n_gt = len(gt_trajectories)
n_tracked = len(tracked_trajectories)
if n_gt:
    ratio_pct = n_tracked / n_gt * 100
    print(f"Track count: {n_tracked} tracked vs {n_gt} ground-truth trajectories "
          f"({ratio_pct:.1f}%; >100% suggests fragmented tracks)")
else:
    print(f"Track count: {n_tracked} tracked (no ground-truth trajectories to compare)")

## 3. Compute Motility Metrics

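Calculate WHO-standardized (CASA) velocity metrics and physics-based trajectory analysis:

- **VCL** (curvilinear velocity): time-averaged speed along the actual frame-to-frame path
- **VSL** (straight-line velocity): straight-line distance between the first and last positions divided by elapsed time
- **VAP** (average-path velocity): time-averaged speed along a smoothed average path
- **LIN** (linearity) = VSL / VCL; **WOB** (wobble) = VAP / VCL
- **ALH** (amplitude of lateral head displacement): magnitude of the head's lateral excursion about the average path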

In [None]:
# Analyze all trajectories: one metrics dict per track, tagged with its track ID so
# rows can be traced back to the tracker output.
print("Computing metrics for all tracks...")

all_metrics = []

for track in all_tracks:
    trajectory = track.get_trajectory()

    # Named `track_metrics` (not `metrics`) to avoid confusion with the `metrics` package.
    track_metrics = analyze_single_trajectory(
        trajectory,
        fps=metadata['fps'],
        pixel_size_um=metadata['pixel_size_um']
    )

    track_metrics['track_id'] = track.track_id
    all_metrics.append(track_metrics)

print(f"✓ Analyzed {len(all_metrics)} trajectories")

### Display Sample Metrics

In [None]:
# Build a per-track metrics table (pandas is imported in the top imports cell).
# `display` is available as a builtin inside IPython/Jupyter.
df = pd.DataFrame(all_metrics)

# Display key metrics
key_metrics = ['track_id', 'VCL', 'VSL', 'VAP', 'LIN', 'WOB',
               'motility_classification', 'diffusion_coefficient']

print("\nMotility Metrics Summary:")
print("=" * 80)
display(df[key_metrics].head(10))

# Summary statistics
print("\nSummary Statistics:")
print("=" * 80)
display(df[['VCL', 'VSL', 'LIN', 'diffusion_coefficient']].describe())

### Motility Classification

In [None]:
# Tally tracks per motility class, plot the distribution, then print percentages.
classification_counts = df['motility_classification'].value_counts()

fig, ax = plt.subplots(figsize=(8, 6))
classification_counts.plot(kind='bar', color='steelblue', edgecolor='black', ax=ax, rot=45)
ax.set_title('Motility Classification Distribution', fontsize=14, fontweight='bold')
ax.set_xlabel('Classification')
ax.set_ylabel('Count')
ax.grid(True, alpha=0.3, axis='y')
fig.tight_layout()
plt.show()

print("\nClassification Breakdown:")
n_total = len(df)
for cls, count in classification_counts.items():
    print(f"  {cls}: {count} ({count / n_total * 100:.1f}%)")

## 4. Visualization

Generate publication-quality plots.

### Velocity Distributions

In [None]:
# Per-track distributions of the WHO kinematic parameters.
distribution_metrics = ['VCL', 'VSL', 'VAP', 'LIN', 'WOB', 'ALH']
plot_velocity_distributions(all_metrics, metric_names=distribution_metrics)

### Mean Squared Displacement

In [None]:
# Mean squared displacement of one representative track (MSD helpers are
# imported in the top imports cell).
# Use the longest track rather than tracks[0]:
# - Tracks can be as short as `min_track_length` frames, too few for MSD out to 50 lags.
# - Indexing [0] fails outright when no tracks were found.
if not tracked_trajectories:
    raise RuntimeError("No tracked trajectories available for MSD analysis.")

sample_traj = max(tracked_trajectories, key=len) * metadata['pixel_size_um']  # pixels → μm

# Never ask for more lags than the trajectory has displacements.
# NOTE(review): assumes compute_MSD requires max_lag < len(trajectory) — confirm.
max_lag = min(50, len(sample_traj) - 1)
lags, msd_values = compute_MSD(sample_traj, max_lag=max_lag)

# Fit diffusion model
fit_params = fit_MSD_diffusion(lags, msd_values, dt=1.0/metadata['fps'])

plot_MSD_curve(lags, msd_values, metadata['fps'], fit_params=fit_params)

print("\nDiffusion Analysis:")
print(f"  D = {fit_params['D']:.2f} μm²/s")
print(f"  α = {fit_params['alpha']:.2f}")
print(f"  Regime: {fit_params['regime']}")

### Correlation Matrix

In [None]:
# Pairwise correlations between WHO kinematic metrics and trajectory-physics metrics
# (plot_correlation_matrix is imported in the top imports cell).
# NOTE(review): assumes every listed key is present in each dict returned by
# analyze_single_trajectory — confirm.
metric_names = ['VCL', 'VSL', 'VAP', 'LIN', 'WOB', 'diffusion_coefficient',
                'persistence_length', 'directional_persistence']

plot_correlation_matrix(
    all_metrics,
    metric_names=metric_names,
    figsize=(10, 8)
)

## 5. Export Results

In [None]:
# Save metrics to CSV
output_dir = Path('../data/results')
output_dir.mkdir(parents=True, exist_ok=True)

csv_path = output_dir / 'metrics_synthetic.csv'
df.to_csv(csv_path, index=False)
print(f"✓ Metrics saved to {csv_path}")

# Save trajectories (presumably pixel coordinates) with their track IDs, so each
# trajectory can be joined back to its row (`track_id`) in the metrics CSV.
# `tracked_trajectories` was built from `all_tracks` in the same order.
# NOTE:
# - Unpickling `parameters` requires the `simulation` package to be importable.
# - Only ever load pickle files from trusted sources.
traj_path = output_dir / 'trajectories_synthetic.pkl'
with open(traj_path, 'wb') as f:
    pickle.dump({
        'trajectories': tracked_trajectories,
        'track_ids': [track.track_id for track in all_tracks],
        'metadata': metadata,
        'parameters': params
    }, f)
print(f"✓ Trajectories saved to {traj_path}")

## Summary

This notebook demonstrated the complete sperm quantification pipeline:

✅ **Synthetic Data Generation**: Physics-based ABP simulation  
✅ **Detection**: Blob detection with configurable parameters  
✅ **Tracking**: Multi-object tracking with Kalman filtering  
✅ **Metrics**: WHO-standardized velocity metrics + physics analysis  
✅ **Visualization**: Publication-ready plots  
✅ **Export**: CSV (metrics) and pickle (trajectories, metadata and simulation parameters) for further analysis  

### Next Steps

- Try with real microscopy data
- Adjust detection/tracking parameters for your data
- Compare X- vs Y-chromosome-bearing sperm populations (if labeled data are available)
- Integrate with microfluidic analysis