# 🌐 Mechano-Velocity: Notebook 03 - Graph Simulation

**Build Spatial Graph and Correct Velocity Vectors**

This notebook:
1. Builds the spatial hexagonal graph
2. Computes RNA velocity with scVelo (when available)
3. Applies physics-based resistance correction
4. Identifies trapped cells
5. Visualizes corrected velocity streamplots

---

## The Correction Equation

**Edge Weight:**
$$W_{ij} = \text{Similarity}(i, j) \times (1 - R_j)$$

**Corrected Velocity:**
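$$\vec{v}_{\text{corrected}}^{(i)} = \sum_{j \in \text{neighbors}(i)} W_{ij} \cdot (\vec{x}_j - \vec{x}_i)$$

Here $R_j$ is the mechanical resistance of spot $j$ computed in `02_Mechanotyping.ipynb`, $\text{Similarity}(i, j)$ is the expression similarity between spots $i$ and $j$, and $\vec{x}_i$ is the spatial position of spot $i$. Because the weight uses the resistance of the *target* spot, a neighbor with $R_j = 1$ gets $W_{ij} = 0$ and contributes nothing to the sum.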
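The two equations reduce to a sparse matrix product, because $\sum_j W_{ij}(\vec{x}_j - \vec{x}_i) = (W X)_i - \left(\sum_j W_{ij}\right)\vec{x}_i$. The sketch below assumes a similarity matrix `S` that is already zero outside spatial neighbors, a resistance vector `R` in $[0, 1]$, and 2-D coordinates `X`; the function name `corrected_velocity` is hypothetical and not part of `mechano_velocity`.

```python
import numpy as np
from scipy import sparse


def corrected_velocity(S, R, X):
    """Hypothetical sketch of the correction equation.

    S: (n, n) similarity matrix (dense or sparse), nonzero only for neighbors.
    R: (n,) resistance per spot, expected in [0, 1].
    X: (n, 2) spatial coordinates.
    Returns an (n, 2) array of corrected velocity vectors.
    """
    S = sparse.csr_matrix(S)
    # W_ij = Similarity(i, j) * (1 - R_j): scale column j by (1 - R_j).
    W = S @ sparse.diags(1.0 - R)
    # sum_j W_ij (x_j - x_i) = (W X)_i - (sum_j W_ij) * x_i
    row_sums = np.asarray(W.sum(axis=1)).ravel()
    return W @ X - row_sums[:, None] * X


# Three spots on a line; spot 2 is fully resistant (R = 1).
X = np.array([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]])
S = np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=float)
R = np.array([0.0, 0.0, 1.0])
V = corrected_velocity(S, R, X)
```

In this toy case `V` is `[[1, 0], [-1, 0], [-1, 0]]`: spot 1 has neighbors on both sides, but the edge into spot 2 has zero weight, so its velocity points only toward spot 0. The identity used above avoids materializing the dense $n \times n \times 2$ displacement tensor, which matters for full Visium sections.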

## 1. Setup

In [None]:
# Check environment
import sys
IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    %cd /content/mechano-velocity
    !pip install -q scanpy scvelo squidpy torch-geometric

In [None]:
# Core imports
import scanpy as sc
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy import sparse
import warnings
warnings.filterwarnings('ignore')

sc.settings.verbosity = 2
sc.settings.set_figure_params(dpi=100, frameon=False, figsize=(8, 8))

# Try to import scvelo
try:
    import scvelo as scv
    scv.settings.verbosity = 2
    SCVELO_AVAILABLE = True
    print(f"scVelo version: {scv.__version__}")
except ImportError:
    SCVELO_AVAILABLE = False
    print("scVelo not available - will use spatial-only velocity")

In [None]:
# Import project modules
from pathlib import Path
sys.path.insert(0, '.')

from mechano_velocity import (
    Config, GraphBuilder, VelocityCorrector, Visualizer
)

PROJECT_ROOT = Path('.').resolve()
print(f"Project: {PROJECT_ROOT}")

## 2. Load Mechanotyped Data

In [None]:
# Load configuration
config = Config()
config.output_dir = PROJECT_ROOT / "output"

# Load mechanotyped data
adata_path = config.output_dir / 'mechanotyped_adata.h5ad'

if adata_path.exists():
    adata = sc.read_h5ad(adata_path)
    print(f"Loaded: {adata.shape}")
else:
    raise FileNotFoundError(f"{adata_path} not found. Please run 02_Mechanotyping.ipynb first.")

In [None]:
# Verify resistance is computed
if 'resistance' not in adata.obs.columns:
    raise ValueError("Resistance not found. Run mechanotyping first.")

print("\nData Overview:")
print(f"  Spots: {adata.n_obs}")
print(f"  Resistance range: [{adata.obs['resistance'].min():.3f}, {adata.obs['resistance'].max():.3f}]")
print(f"  Layers: {list(adata.layers.keys())}")

## 3. Build Spatial Graph

In [None]:
# Initialize graph builder
graph_builder = GraphBuilder(config)

In [None]:
# Build spatial graph
# - KNN method with 6 neighbors (hexagonal grid)
# - Include resistance weighting
# - Include expression similarity

adjacency = graph_builder.build_spatial_graph(
    adata,
    method='knn',
    k_neighbors=6,
    include_resistance=True,
    include_similarity=True
)
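On a Visium-style hexagonal lattice, each interior spot has exactly six equidistant nearest neighbors, which is why `k_neighbors=6` recovers the hex grid. A minimal sketch of a resistance-weighted KNN graph follows, assuming cosine similarity of expression profiles as the `Similarity(i, j)` term; the helper `knn_spatial_graph` is hypothetical, and `GraphBuilder.build_spatial_graph` may choose a different similarity, normalization, or symmetrization.

```python
import numpy as np
from scipy import sparse
from scipy.spatial import cKDTree


def knn_spatial_graph(coords, expr, resistance, k=6):
    """Hypothetical resistance-weighted KNN graph (directed, row i -> col j)."""
    n = coords.shape[0]
    tree = cKDTree(coords)
    # Query k + 1 neighbors because each spot's nearest neighbor is itself.
    _, idx = tree.query(coords, k=k + 1)
    rows = np.repeat(np.arange(n), k)
    cols = idx[:, 1:].ravel()
    # Cosine similarity between expression profiles (an assumption).
    norms = np.maximum(np.linalg.norm(expr, axis=1, keepdims=True), 1e-12)
    unit = expr / norms
    sim = np.sum(unit[rows] * unit[cols], axis=1)
    # Edge weight W_ij = Similarity(i, j) * (1 - R_j).
    weights = sim * (1.0 - resistance[cols])
    return sparse.csr_matrix((weights, (rows, cols)), shape=(n, n))


rng = np.random.default_rng(0)
coords = rng.uniform(0, 10, size=(50, 2))
expr = rng.poisson(5, size=(50, 20)).astype(float)
R = rng.uniform(0, 0.9, size=50)
A = knn_spatial_graph(coords, expr, R)
```

Because the weight depends on $R_j$ rather than $R_i$, the resulting matrix is not symmetric: an edge into a stiff spot is weaker than the reverse edge out of it.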

In [None]:
# View graph statistics
metrics = graph_builder.metrics

print("\nüìä Graph Statistics:")
print(f"  Nodes: {metrics.n_nodes}")
print(f"  Edges: {metrics.n_edges}")
print(f"  Avg degree: {metrics.avg_degree:.2f}")
print(f"  Connectivity: {metrics.connectivity:.2%}")
print(f"  Avg edge weight: {metrics.avg_edge_weight:.3f}")

In [None]:
# Visualize edge weights
edge_weights = adjacency.data

fig, ax = plt.subplots(figsize=(10, 5))
ax.hist(edge_weights, bins=50, color='steelblue', edgecolor='white', alpha=0.7)
ax.axvline(x=edge_weights.mean(), color='red', linestyle='--', 
           label=f'Mean: {edge_weights.mean():.3f}')
ax.set_xlabel('Edge Weight', fontsize=12)
ax.set_ylabel('Frequency', fontsize=12)
ax.set_title('Distribution of Edge Weights\n(Higher = Easier Transition)', fontsize=14)
ax.legend()
plt.tight_layout()
plt.savefig(config.output_dir / 'edge_weight_distribution.png', dpi=150)
plt.show()

## 4. Compute RNA Velocity (Optional)

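If scVelo is installed and the data contain spliced and unspliced count layers, we compute RNA velocity with scVelo's stochastic model.
Otherwise, the correction step falls back to a spatial-only velocity derived from the graph structure.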
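For intuition about what the velocity step estimates, here is a simplified version of the steady-state model from velocyto/scVelo: each gene gets a degradation ratio $\gamma_g$ from a least-squares fit of unspliced against spliced counts, and velocity is the residual $v = u - \gamma s$. This sketch fits $\gamma$ on all cells rather than on extreme quantiles, and it is not what `compute_rna_velocity(mode='stochastic')` runs, which also uses second-order moments; `steady_state_velocity` is a hypothetical helper.

```python
import numpy as np


def steady_state_velocity(spliced, unspliced):
    """Simplified steady-state RNA velocity with one gamma per gene.

    spliced, unspliced: (n_cells, n_genes) arrays, ideally kNN-smoothed moments.
    Returns (velocity, gamma).
    """
    # Least-squares slope through the origin: gamma_g = <s_g, u_g> / <s_g, s_g>.
    num = np.sum(spliced * unspliced, axis=0)
    den = np.sum(spliced * spliced, axis=0)
    gamma = np.divide(num, den, out=np.zeros_like(num), where=den > 0)
    # Velocity is how far unspliced sits above the steady-state line u = gamma * s.
    velocity = unspliced - gamma * spliced
    return velocity, gamma


s = np.array([[1.0, 2.0], [2.0, 4.0], [3.0, 1.0]])
u = 0.5 * s  # every cell exactly on the steady-state line
v, g = steady_state_velocity(s, u)
```

Positive velocity for a gene means it is being induced (more unspliced RNA than steady state predicts); negative velocity means it is being repressed.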

In [None]:
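# Check for spliced/unspliced layers (scVelo needs both, either raw or as moments)
has_velocity_layers = (
    {'spliced', 'unspliced'}.issubset(adata.layers.keys())
    or {'Ms', 'Mu'}.issubset(adata.layers.keys())
)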

print(f"Has velocity layers: {has_velocity_layers}")
print(f"scVelo available: {SCVELO_AVAILABLE}")

In [None]:
# Initialize velocity corrector
velocity_corrector = VelocityCorrector(config)

In [None]:
if SCVELO_AVAILABLE and has_velocity_layers:
    # Compute RNA velocity
    print("Computing RNA velocity with scVelo...")
    velocity_corrector.compute_rna_velocity(adata, mode='stochastic')
else:
    print("Skipping RNA velocity (spatial-only mode)")
    print("Will use spatial graph structure for velocity computation.")

## 5. Apply Resistance Correction

In [None]:
# Apply physics-based correction
# This projects velocity onto allowed directions (away from high-resistance "walls")

corrected_velocity = velocity_corrector.apply_resistance_correction(
    adata,
    graph_builder=graph_builder,
    method='projection'  # Options: 'projection', 'scaling', 'hard_threshold'
)

In [None]:
# View results
print("\nüìç Corrected Velocity:")
print(f"  Shape: {corrected_velocity.shape}")
print(f"  Stored in: adata.obsm['velocity_corrected']")
print(f"  Magnitude in: adata.obs['velocity_magnitude']")

In [None]:
# Compare velocities
comparison = velocity_corrector.compare_velocities(adata)

print("\nüìà Velocity Analysis:")
for key, value in comparison.items():
    if isinstance(value, float):
        print(f"  {key}: {value:.4f}")
    else:
        print(f"  {key}: {value}")
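The keys returned by `compare_velocities` are defined by `mechano_velocity`. For an independent per-spot check, one option (an assumption, not necessarily what `compare_velocities` computes) is the cosine between raw and corrected vectors, provided both are 2-D vectors in the same spatial frame, for example raw velocity already projected onto spatial coordinates. `per_spot_cosine` is a hypothetical helper.

```python
import numpy as np


def per_spot_cosine(v_raw, v_corr, eps=1e-12):
    """Cosine of the angle between raw and corrected velocity for each spot."""
    num = np.sum(v_raw * v_corr, axis=1)
    den = np.linalg.norm(v_raw, axis=1) * np.linalg.norm(v_corr, axis=1)
    # Guard against zero-length vectors (e.g. fully blocked spots).
    return num / np.maximum(den, eps)


v_raw = np.array([[1.0, 0.0], [0.0, 1.0]])
v_corr = np.array([[2.0, 0.0], [1.0, 0.0]])
cos = per_spot_cosine(v_raw, v_corr)
```

A cosine near 1 means the correction mainly rescaled the vector; values near 0 or below mean resistance redirected it.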

## 6. Identify Trapped Cells

In [None]:
# Identify cells that are trapped (high resistance, low velocity)
# These are cells that WANT to move but CANNOT

trapped_mask = velocity_corrector.identify_trapped_cells(
    adata,
    velocity_threshold=0.01,
    resistance_threshold=0.8
)
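The trapped-cell rule described in the comments (high resistance, low velocity) is a simple boolean mask. A minimal sketch, assuming velocity magnitude is the Euclidean norm of the 2-D corrected vector and that both thresholds are strict; `identify_trapped_cells` may use different comparisons, and `trapped_mask` is a hypothetical name.

```python
import numpy as np


def trapped_mask(velocity, resistance, velocity_threshold=0.01, resistance_threshold=0.8):
    """Hypothetical trapped-cell rule: low speed AND high resistance."""
    # Speed of each spot from its corrected velocity vector.
    magnitude = np.linalg.norm(velocity, axis=1)
    return (magnitude < velocity_threshold) & (resistance > resistance_threshold)


velocity = np.array([[0.0, 0.0], [1.0, 0.0], [0.001, 0.0]])
resistance = np.array([0.9, 0.95, 0.5])
mask = trapped_mask(velocity, resistance)
```

Only the first spot is flagged: the second is stiff but still moving, and the third is slow but sits in soft tissue.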

In [None]:
# Visualize trapped cells
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Velocity magnitude
sc.pl.spatial(
    adata, color='velocity_magnitude', ax=axes[0], show=False,
    title='Velocity Magnitude', cmap='viridis'
)

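# Trapped cells (cast the boolean flag to a categorical so the palette applies)
adata.obs['is_trapped_cat'] = pd.Categorical(
    adata.obs['is_trapped'].astype(str), categories=['False', 'True']
)
sc.pl.spatial(
    adata, color='is_trapped_cat', ax=axes[1], show=False,
    title='Trapped Cells (High R, Low V)',
    palette={'True': 'red', 'False': 'lightgray'}
)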

plt.tight_layout()
plt.savefig(config.output_dir / 'trapped_cells.png', dpi=150)
plt.show()

## 7. Visualize Velocity Fields

In [None]:
# Initialize visualizer
viz = Visualizer(config)

In [None]:
# Plot velocity arrows (quiver plot)
fig = viz.plot_velocity_arrows(
    adata,
    velocity_key='velocity_corrected',
    color_by='resistance',
    arrow_scale=0.5,
    subsample=1,
    title='Physics-Constrained Velocity Vectors',
    save_path=config.output_dir / 'velocity_arrows.png'
)
plt.show()

In [None]:
# Plot streamplot
fig = viz.plot_velocity_streamplot(
    adata,
    velocity_key='velocity_corrected',
    color_by='resistance',
    grid_resolution=30,
    title='Velocity Streamlines (Colored by Resistance)',
    save_path=config.output_dir / 'velocity_streamplot.png'
)
plt.show()
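A streamplot needs velocities on a regular grid, while spots are scattered, so `grid_resolution` presumably controls an interpolation step. A minimal sketch of one way to do this with `scipy.interpolate.griddata`; this is an assumption about how `plot_velocity_streamplot` works, and `velocity_grid` is a hypothetical helper.

```python
import numpy as np
from scipy.interpolate import griddata


def velocity_grid(coords, velocity, grid_resolution=30):
    """Interpolate scattered spot velocities onto a regular grid."""
    x = np.linspace(coords[:, 0].min(), coords[:, 0].max(), grid_resolution)
    y = np.linspace(coords[:, 1].min(), coords[:, 1].max(), grid_resolution)
    gx, gy = np.meshgrid(x, y)  # shape (len(y), len(x)), as streamplot expects
    # Linear interpolation; grid points outside the convex hull of the spots become NaN.
    u = griddata(coords, velocity[:, 0], (gx, gy), method='linear')
    v = griddata(coords, velocity[:, 1], (gx, gy), method='linear')
    return x, y, u, v


rng = np.random.default_rng(1)
coords = rng.uniform(0, 10, size=(60, 2))
velocity = np.ones((60, 2))
x, y, u, v = velocity_grid(coords, velocity)
```

The outputs can be passed directly to `plt.streamplot(x, y, u, v)`; matplotlib's streamplot leaves NaN regions empty, so no streamlines are drawn outside the tissue.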

In [None]:
# Three-panel overview
fig = viz.plot_comparison(
    adata,
    save_path=config.output_dir / 'analysis_overview.png'
)
plt.show()

## 8. Flow Analysis

In [None]:
# Analyze flow patterns by cluster
if 'leiden' in adata.obs.columns:
    # Calculate mean velocity magnitude per cluster
    cluster_velocity = adata.obs.groupby('leiden').agg({
        'velocity_magnitude': 'mean',
        'resistance': 'mean',
        'is_trapped': 'sum'
    }).sort_values('velocity_magnitude', ascending=False)
    
    cluster_velocity.columns = ['Avg Velocity', 'Avg Resistance', 'N Trapped']
    
    print("\nüî¨ Flow Analysis by Cluster:")
    print(cluster_velocity.to_string())

In [None]:
# Visualize cluster flow patterns
if 'leiden' in adata.obs.columns:
    fig, axes = plt.subplots(1, 2, figsize=(14, 6))
    
    # Bar plot
    cluster_velocity['Avg Velocity'].plot(
        kind='bar', ax=axes[0], color='steelblue', edgecolor='white'
    )
    axes[0].set_xlabel('Cluster')
    axes[0].set_ylabel('Mean Velocity Magnitude')
    axes[0].set_title('Velocity by Cluster')
    
    # Scatter: Velocity vs Resistance
    axes[1].scatter(
        cluster_velocity['Avg Resistance'],
        cluster_velocity['Avg Velocity'],
        s=100, c='steelblue', alpha=0.7
    )
    for idx, row in cluster_velocity.iterrows():
        axes[1].annotate(str(idx), (row['Avg Resistance'], row['Avg Velocity']),
                        fontsize=10, ha='center', va='bottom')
    axes[1].set_xlabel('Mean Resistance')
    axes[1].set_ylabel('Mean Velocity')
    axes[1].set_title('Velocity vs Resistance by Cluster')
    
    plt.tight_layout()
    plt.savefig(config.output_dir / 'cluster_flow_analysis.png', dpi=150)
    plt.show()

## 9. Convert to PyTorch Geometric (Optional)

For GNN training, convert the graph to PyTorch Geometric format.
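PyTorch Geometric stores connectivity as a `(2, n_edges)` integer `edge_index` in COO order. A minimal sketch of converting the sparse adjacency built above, not necessarily how `to_pytorch_geometric` does it; `to_edge_index` is a hypothetical helper, and it stays in NumPy so it runs without PyTorch installed.

```python
import numpy as np
from scipy import sparse


def to_edge_index(adjacency):
    """Convert a scipy sparse adjacency into PyG-style COO arrays."""
    coo = sparse.coo_matrix(adjacency)
    # Row 0 = source node, row 1 = target node; PyG expects int64 indices.
    edge_index = np.vstack([coo.row, coo.col]).astype(np.int64)
    # Edge weights can be attached as edge attributes.
    edge_weight = coo.data.astype(np.float32)
    return edge_index, edge_weight


A = sparse.csr_matrix(np.array([[0.0, 0.5], [0.7, 0.0]]))
ei, ew = to_edge_index(A)
```

With PyTorch available, `torch.from_numpy(ei)` yields a `torch.long` tensor that can be passed as `edge_index` to `torch_geometric.data.Data`, with `torch.from_numpy(ew)` as `edge_attr`.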

In [None]:
try:
    import torch
    HAS_TORCH = True
except ImportError:
    HAS_TORCH = False
    print("PyTorch not available - skipping PyG conversion")

In [None]:
if HAS_TORCH:
    try:
        # Convert to PyG format
        pyg_data = graph_builder.to_pytorch_geometric(adata)
        
        print("\nüî• PyTorch Geometric Data:")
        print(f"  {pyg_data}")
        print(f"  Node features: {pyg_data.x.shape}")
        print(f"  Edges: {pyg_data.edge_index.shape}")
        print(f"  Has resistance: {hasattr(pyg_data, 'resistance')}")
        
        # Save PyG data
        torch.save(pyg_data, config.output_dir / 'spatial_graph.pt')
        print(f"\nSaved PyG data to: {config.output_dir / 'spatial_graph.pt'}")
        
    except ImportError:
        print("torch-geometric not installed - run: pip install torch-geometric")

## 10. Save Results

In [None]:
# Save final AnnData with all computed values
output_path = config.output_dir / 'velocity_corrected_adata.h5ad'
adata.write_h5ad(output_path)
print(f"Saved: {output_path}")

In [None]:
# Export velocity vectors to CSV
velocity_df = pd.DataFrame({
    'spot_id': adata.obs_names,
    'x': adata.obsm['spatial'][:, 0],
    'y': adata.obsm['spatial'][:, 1],
    'velocity_x': adata.obsm['velocity_corrected'][:, 0],
    'velocity_y': adata.obsm['velocity_corrected'][:, 1],
    'velocity_magnitude': adata.obs['velocity_magnitude'].values,
    'resistance': adata.obs['resistance'].values,
    'is_trapped': adata.obs['is_trapped'].values,
})

csv_path = config.output_dir / 'velocity_vectors.csv'
velocity_df.to_csv(csv_path, index=False)
print(f"Exported: {csv_path}")

## Summary

✅ Built spatial hexagonal graph with 6 neighbors  
✅ Computed resistance-weighted edge values  
✅ Applied physics-based velocity correction  
✅ Identified trapped cells (high R, low V)  
✅ Visualized corrected velocity field  
✅ Analyzed flow patterns by cluster  
✅ Converted to PyTorch Geometric format (when PyTorch is available)  
✅ Saved all results  

**Next: Run `04_Training_Validation.ipynb` for GNN training and clinical scoring.**