# Voronoi-based Void Detection: MLP vs GNN

This notebook demonstrates how to use Voronoi tessellation features to detect cosmic voids while avoiding spatial data leakage.

## Key Concepts

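1. **Spatial Data Leakage**: A model given raw galaxy coordinates can memorize *where* void galaxies are instead of learning what makes a region a void. With a random train/test split, test galaxies sit next to training galaxies, so this memorization inflates test scores.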
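2. **Voronoi Features**: Per-cell properties (cell volume, number of Voronoi neighbors). They describe local density and connectivity without encoding absolute position, so they avoid this leakage.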
3. **MLP vs GNN**: Comparing simple feature-based vs graph-based approaches

In [None]:
import warnings
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch

# NOTE(review): this silences *all* warnings for the whole session, which can hide
# real problems (e.g. dtype/shape warnings from NumPy or PyTorch); consider narrowing it.
warnings.filterwarnings('ignore')

# Import VoidX modules
from voidx.voronoi import VoronoiFeatureExtractor
from voidx.models import VoidMLP
from voidx.data import GalaxyDataset, split_indices_stratified, normalize_features

# GNN support is optional: it needs torch_geometric plus the voidx GNN models.
try:
    from voidx.models import VoronoiGCN, VoronoiGAT
    from torch_geometric.data import Data
    GNN_AVAILABLE = True
except ImportError:
    GNN_AVAILABLE = False
    print("torch_geometric not installed. GNN models will not be available.")
    print("Install with: pip install torch-geometric")

# Seed the global RNGs used below (NumPy, and torch for weight initialisation,
# dropout and DataLoader shuffling) so Restart & Run All is reproducible.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

## 1. Load Data

Load galaxy positions and void-membership labels from `data/knn_data.npz` (e.g. as prepared by `data_preparation_synthetic.ipynb`). If that file does not exist, a small synthetic catalog is generated instead.

In [None]:
# Load a prepared catalog if one exists; otherwise fall back to a small synthetic one.
data_path = Path('data/knn_data.npz')

if data_path.exists():
    print(f"Loading data from {data_path}...")
    # Context manager so the underlying .npz file handle is closed after reading.
    with np.load(data_path) as catalog:
        positions = catalog['positions']
        labels = catalog['membership'].astype(np.float32)
    print(f"Loaded {len(positions)} galaxies")
    print(f"Void fraction: {labels.mean():.2%}")
else:
    print("No existing data found. Generating synthetic data...")
    # Synthetic catalog: uniformly distributed galaxies in a periodic box with a few
    # spherical, under-dense "voids" carved out. Void membership is determined by
    # position, so local density (and hence Voronoi cell volume) carries a real signal.
    # Drawing void and non-void galaxies from the same uniform distribution would make
    # the labels pure noise that no model could learn.
    rng = np.random.default_rng(42)
    n_galaxies = 5000
    box_size = 250.0
    n_voids = 10
    void_radius = 50.0
    void_keep_fraction = 0.25  # fraction of galaxies retained inside voids (density contrast)

    void_centers = rng.uniform(0, box_size, size=(n_voids, 3))

    # Over-sample candidates so that enough survive the thinning of void interiors.
    candidates = rng.uniform(0, box_size, size=(3 * n_galaxies, 3))
    offsets = candidates[:, np.newaxis, :] - void_centers[np.newaxis, :, :]
    offsets -= box_size * np.round(offsets / box_size)  # minimum-image (periodic) separation
    in_void = (np.linalg.norm(offsets, axis=-1) < void_radius).any(axis=1)

    # Keep every galaxy outside voids and only a fraction of those inside.
    keep = ~in_void | (rng.random(len(candidates)) < void_keep_fraction)
    # Candidates are i.i.d., so taking the first n_galaxies survivors is an unbiased,
    # already-shuffled sample.
    positions = candidates[keep][:n_galaxies]
    labels = in_void[keep][:n_galaxies].astype(np.float32)

    print(f"Generated {len(positions)} galaxies ({labels.mean():.2%} in voids)")

## 2. Compute Voronoi Features

Extract Voronoi tessellation features: per-cell volumes, neighbor counts, and the cell-adjacency graph (its edge list is used later by the GNN).

In [None]:
# Side length of the periodic box. It must match the catalog loaded above
# (the synthetic fallback also uses 250).
box_size = 250.0  # Adjust based on your data

# Coordinates outside [0, box_size] almost certainly mean box_size does not match
# the catalog, which would corrupt the periodic tessellation. Fail loudly rather than
# warn, because warnings are silenced for this notebook.
pos_min, pos_max = positions.min(), positions.max()
if pos_min < 0 or pos_max > box_size:
    raise ValueError(
        f"Galaxy coordinates span [{pos_min:.2f}, {pos_max:.2f}], outside the "
        f"periodic box [0, {box_size}]; set box_size to match your data."
    )

# Create Voronoi feature extractor
extractor = VoronoiFeatureExtractor(
    box_size=box_size,
    use_periodic=True,
    clip_infinite=True,
)

print("Computing Voronoi tessellation...")
features = extractor.extract_features(positions)

print(f"\nVoronoi statistics:")
print(f"  Number of cells: {len(positions)}")
print(f"  Number of edges: {features['edge_index'].shape[1]}")
print(f"  Avg neighbors per cell: {features['neighbor_count'].mean():.2f}")
print(f"  Median cell volume: {np.nanmedian(features['volumes']):.2f}")

### Visualize Features

Let's look at the distribution of volumes and neighbor counts for void vs non-void galaxies.

In [None]:
def plot_void_vs_nonvoid(ax, values, void_sel, nonvoid_sel, *, bins, xlabel, title):
    """Overlay density-normalised histograms of ``values`` for non-void and void galaxies.

    ``void_sel`` / ``nonvoid_sel`` are boolean masks selecting each class from ``values``.
    """
    ax.hist(values[nonvoid_sel], bins=bins, alpha=0.5, label='Non-void', density=True)
    ax.hist(values[void_sel], bins=bins, alpha=0.5, label='Void', density=True)
    ax.set_xlabel(xlabel)
    ax.set_ylabel('Density')
    ax.set_title(title)
    ax.legend()
    ax.grid(alpha=0.3)


void_mask = labels == 1
nonvoid_mask = labels == 0

volumes = features['volumes']
# log10 is only defined for finite, positive volumes. Checking ~isnan alone would
# let inf (or zero) volumes through, which breaks both log10 and histogram binning.
valid_mask = np.isfinite(volumes) & (volumes > 0)
log_volumes = np.full(volumes.shape, np.nan)
log_volumes[valid_mask] = np.log10(volumes[valid_mask])
neighbor_count = features['neighbor_count']

fig, axes = plt.subplots(1, 2, figsize=(12, 4))
plot_void_vs_nonvoid(
    axes[0], log_volumes, valid_mask & void_mask, valid_mask & nonvoid_mask,
    bins=50, xlabel='Log10(Cell Volume)', title='Cell Volume Distribution',
)
plot_void_vs_nonvoid(
    axes[1], neighbor_count, void_mask, nonvoid_mask,
    bins=30, xlabel='Number of Neighbors', title='Neighbor Count Distribution',
)
plt.tight_layout()
plt.show()

print("\nFeature statistics:")
print(f"Void cells - Mean log(volume): {log_volumes[valid_mask & void_mask].mean():.2f}")
print(f"Non-void cells - Mean log(volume): {log_volumes[valid_mask & nonvoid_mask].mean():.2f}")
print(f"Void cells - Mean neighbors: {neighbor_count[void_mask].mean():.2f}")
print(f"Non-void cells - Mean neighbors: {neighbor_count[nonvoid_mask].mean():.2f}")

## 3. Prepare Data for Training

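Split the data into train/validation/test sets (70% / 15% / remaining ~15%) using stratified sampling, so each split keeps the same void fraction. This split is random rather than spatial, so on its own it does not rule out spatial leakage (see Next Steps).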

In [None]:
# Stratified split by label with a fixed seed: 70% train, 15% validation, and the
# remaining indices (presumably ~15%) for test.
train_idx, val_idx, test_idx = split_indices_stratified(labels, train=0.7, val=0.15, seed=42)

print("Data split:")
# The padded names keep the sample counts aligned in the printout.
for split_name, split_idx in (("Train:", train_idx), ("Val:  ", val_idx), ("Test: ", test_idx)):
    print(f"  {split_name} {len(split_idx)} samples ({labels[split_idx].mean():.2%} void)")

## 4. Train MLP Model

Train an MLP on per-cell Voronoi features only (volume and neighbor count). No coordinates are given to the model, so it cannot exploit spatial leakage.

In [None]:
# Position-free features, so the MLP has no coordinates it could memorise.
mlp_features = extractor.create_mlp_features(positions, include_positions=False)

split_indices = (train_idx, val_idx, test_idx)
X_train, X_val, X_test = (mlp_features[idx] for idx in split_indices)

# Rescale the three splits; the fourth return value (the normalisation
# statistics) is not needed for the MLP.
X_train, X_val, X_test, _ = normalize_features(X_train, X_val, X_test)

y_train, y_val, y_test = (labels[idx] for idx in split_indices)

print(f"MLP features shape: {mlp_features.shape}")
print("Features: normalized_volume, neighbor_count")

In [None]:
# Wrap each split in a GalaxyDataset; only the training loader shuffles.
train_dataset = GalaxyDataset(X_train, y_train)
val_dataset = GalaxyDataset(X_val, y_val)
test_dataset = GalaxyDataset(X_test, y_test)

make_loader = torch.utils.data.DataLoader
train_loader = make_loader(train_dataset, batch_size=64, shuffle=True)
val_loader = make_loader(val_dataset, batch_size=64, shuffle=False)
test_loader = make_loader(test_dataset, batch_size=64, shuffle=False)

# Prefer the GPU when one is visible to torch.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

# One input unit per feature column; three hidden layers with dropout.
mlp_model = VoidMLP(in_features=X_train.shape[1], hidden_layers=(128, 64, 32), dropout=0.3)
mlp_model = mlp_model.to(device)

n_mlp_params = sum(param.numel() for param in mlp_model.parameters())
print(f"\nModel parameters: {n_mlp_params:,}")

In [None]:
# Training setup
optimizer = torch.optim.Adam(mlp_model.parameters(), lr=0.001, weight_decay=1e-5)
# BCELoss expects probabilities in [0, 1]. Together with the 0.5 threshold below, this
# assumes VoidMLP ends in a sigmoid. TODO confirm (the GNN cell uses logits instead).
criterion = torch.nn.BCELoss()

# Training loop
epochs = 100
best_val_loss = float('inf')
best_epoch = 0
# state_dict() returns references to the live parameter tensors, which the optimizer
# updates in place, so a shallow dict .copy() would end up holding the *final* weights.
# Clone every tensor to take a real snapshot. It is also initialised here so the
# restore below works even if the validation loss never improves (e.g. it is NaN).
best_model_state = {name: t.detach().clone() for name, t in mlp_model.state_dict().items()}
train_losses = []
val_losses = []
val_accs = []

print("Training MLP...")
for epoch in range(epochs):
    # Training: one pass over the shuffled training batches.
    mlp_model.train()
    train_loss_sum = 0.0
    train_total = 0
    for X_batch, y_batch in train_loader:
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)

        optimizer.zero_grad()
        pred = mlp_model(X_batch)
        loss = criterion(pred, y_batch)
        loss.backward()
        optimizer.step()

        # criterion returns a batch mean; weight by batch size so the epoch loss is a
        # true per-sample mean (the last batch is usually smaller).
        train_loss_sum += loss.item() * y_batch.size(0)
        train_total += y_batch.size(0)

    train_loss = train_loss_sum / train_total
    train_losses.append(train_loss)

    # Validation: eval mode (dropout off), no gradients.
    mlp_model.eval()
    val_loss_sum = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for X_batch, y_batch in val_loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            pred = mlp_model(X_batch)
            val_loss_sum += criterion(pred, y_batch).item() * y_batch.size(0)

            predicted = (pred > 0.5).float()
            correct += (predicted == y_batch).sum().item()
            total += y_batch.size(0)

    val_loss = val_loss_sum / total
    val_acc = correct / total
    val_losses.append(val_loss)
    val_accs.append(val_acc)

    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1:3d}: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_epoch = epoch + 1
        best_model_state = {name: t.detach().clone() for name, t in mlp_model.state_dict().items()}

# Load best model (lowest validation loss)
mlp_model.load_state_dict(best_model_state)
print("\nTraining complete!")
print(f"Restored best model from epoch {best_epoch} (val loss {best_val_loss:.4f})")

In [None]:
# Left panel: train vs validation loss per epoch. Right panel: validation accuracy.
fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 4))

loss_ax.plot(train_losses, label='Train')
loss_ax.plot(val_losses, label='Validation')
loss_ax.set(xlabel='Epoch', ylabel='Loss', title='Training History - Loss')
loss_ax.legend()

acc_ax.plot(val_accs)
acc_ax.set(xlabel='Epoch', ylabel='Accuracy', title='Validation Accuracy')

for panel in (loss_ax, acc_ax):
    panel.grid(alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
def binary_classification_metrics(pred_labels, true_labels):
    """Return ``(accuracy, precision, recall, f1)`` for binary 0/1 predictions.

    Both inputs are flattened first, so an (N, 1) array compared against an (N,) array
    is matched element-wise instead of broadcasting to an (N, N) matrix. Labels are
    assumed to be exactly 0/1. Precision, recall and F1 fall back to 0 when their
    denominators are zero.
    """
    pred_void = np.asarray(pred_labels).ravel() == 1
    true_void = np.asarray(true_labels).ravel() == 1

    tp = np.sum(pred_void & true_void)
    fp = np.sum(pred_void & ~true_void)
    fn = np.sum(~pred_void & true_void)

    accuracy = np.mean(pred_void == true_void)
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
    return accuracy, precision, recall, f1


# Evaluate on test set: collect predicted probabilities and true labels batch by batch.
mlp_model.eval()
test_preds = []
test_labels = []

with torch.no_grad():
    for X_batch, y_batch in test_loader:
        pred = mlp_model(X_batch.to(device))
        test_preds.append(pred.cpu().numpy())
        test_labels.append(y_batch.numpy())

test_preds = np.concatenate(test_preds)
test_labels = np.concatenate(test_labels)
test_pred_labels = (test_preds > 0.5).astype(int)

accuracy, precision, recall, f1 = binary_classification_metrics(test_pred_labels, test_labels)

print("\nMLP Test Results:")
print(f"  Accuracy:  {accuracy:.4f}")
print(f"  Precision: {precision:.4f}")
print(f"  Recall:    {recall:.4f}")
print(f"  F1 Score:  {f1:.4f}")

## 5. Train GNN Model (Optional)

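If `torch_geometric` is installed, train a graph convolutional network (`VoronoiGCN`) on the Voronoi adjacency graph. Unlike the MLP, its node features also include the raw galaxy positions.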

In [None]:
if not GNN_AVAILABLE:
    print("Skipping GNN training - torch_geometric not available")
else:
    # Node features: per-cell volume and neighbor count, followed by raw x, y, z.
    # NOTE(review): the intro treats raw positions as a spatial-leakage risk and the MLP
    # deliberately excludes them. Confirm that including them here is intended.
    gnn_features = np.column_stack([
        features['normalized_volumes'],
        features['neighbor_count'],
        positions,  # OK in GNN due to graph context
    ])

    # Normalise per split; presumably the returned (mean, std) are fitted on the
    # training split. TODO confirm against normalize_features.
    X_train_gnn, X_val_gnn, X_test_gnn, (mean, std) = normalize_features(
        gnn_features[train_idx], gnn_features[val_idx], gnn_features[test_idx]
    )

    # The GNN runs on the full graph, so every node gets the same normalisation.
    gnn_features_norm = (gnn_features - mean) / std

    # Boolean node masks marking which galaxies belong to each split.
    n_nodes = len(labels)
    train_mask, val_mask, test_mask = (np.zeros(n_nodes, dtype=bool) for _ in range(3))
    for node_mask, node_idx in ((train_mask, train_idx), (val_mask, val_idx), (test_mask, test_idx)):
        node_mask[node_idx] = True

    print(f"GNN features shape: {gnn_features.shape}")
    print("Features: normalized_volume, neighbor_count, x, y, z")
    print("Note: Positions used in graph context")

In [None]:
if GNN_AVAILABLE:
    # One full-graph Data object: each galaxy is a node (a row of gnn_features_norm)
    # and Voronoi adjacencies are the edges. The masks select which nodes' labels
    # enter each loss or metric.
    data = Data(
        x=torch.as_tensor(gnn_features_norm, dtype=torch.float32),
        edge_index=torch.as_tensor(features['edge_index'], dtype=torch.long),
        y=torch.as_tensor(labels, dtype=torch.float32),
    ).to(device)

    train_mask_t = torch.as_tensor(train_mask, dtype=torch.bool).to(device)
    val_mask_t = torch.as_tensor(val_mask, dtype=torch.bool).to(device)
    test_mask_t = torch.as_tensor(test_mask, dtype=torch.bool).to(device)

    # Create model
    gnn_model = VoronoiGCN(
        in_features=gnn_features.shape[1],
        hidden_channels=64,
        num_layers=3,
        dropout=0.3,
    ).to(device)

    print(f"\nGNN model parameters: {sum(p.numel() for p in gnn_model.parameters()):,}")

    # Training setup: the model output is treated as raw logits (sigmoid is applied
    # explicitly for predictions below).
    optimizer_gnn = torch.optim.Adam(gnn_model.parameters(), lr=0.001, weight_decay=1e-5)
    criterion_gnn = torch.nn.BCEWithLogitsLoss()

    # Training loop
    epochs_gnn = 200
    best_val_loss_gnn = float('inf')
    best_epoch_gnn = 0
    # state_dict() holds references to the live parameters, so a shallow dict .copy()
    # would track the *latest* weights rather than the best ones. Clone the tensors,
    # and initialise the snapshot so the restore below always has something to load.
    best_model_state_gnn = {name: t.detach().clone() for name, t in gnn_model.state_dict().items()}
    train_losses_gnn = []
    val_losses_gnn = []
    val_accs_gnn = []

    print("\nTraining GNN...")
    for epoch in range(epochs_gnn):
        # Training: one full-batch gradient step on the training nodes.
        gnn_model.train()
        optimizer_gnn.zero_grad()

        out = gnn_model(data.x, data.edge_index)
        loss = criterion_gnn(out[train_mask_t], data.y[train_mask_t])
        loss.backward()
        optimizer_gnn.step()

        train_losses_gnn.append(loss.item())

        # Validation: eval-mode forward pass over the whole graph, scored on val nodes.
        gnn_model.eval()
        with torch.no_grad():
            out = gnn_model(data.x, data.edge_index)
            val_loss = criterion_gnn(out[val_mask_t], data.y[val_mask_t]).item()

            pred = torch.sigmoid(out[val_mask_t]) > 0.5
            val_acc = (pred == data.y[val_mask_t]).float().mean().item()

        val_losses_gnn.append(val_loss)
        val_accs_gnn.append(val_acc)

        if (epoch + 1) % 20 == 0:
            print(f"Epoch {epoch+1:3d}: Train Loss: {loss.item():.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

        if val_loss < best_val_loss_gnn:
            best_val_loss_gnn = val_loss
            best_epoch_gnn = epoch + 1
            best_model_state_gnn = {name: t.detach().clone() for name, t in gnn_model.state_dict().items()}

    # Load best model (lowest validation loss)
    gnn_model.load_state_dict(best_model_state_gnn)
    print("\nGNN Training complete!")
    print(f"Restored best GNN from epoch {best_epoch_gnn} (val loss {best_val_loss_gnn:.4f})")

In [None]:
if GNN_AVAILABLE:
    # Plot GNN training history
    fig, axes = plt.subplots(1, 2, figsize=(12, 4))

    ax = axes[0]
    ax.plot(train_losses_gnn, label='Train', alpha=0.7)
    ax.plot(val_losses_gnn, label='Validation')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_title('GNN Training History - Loss')
    ax.legend()
    ax.grid(alpha=0.3)

    ax = axes[1]
    ax.plot(val_accs_gnn)
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Accuracy')
    ax.set_title('GNN Validation Accuracy')
    ax.grid(alpha=0.3)

    plt.tight_layout()
    plt.show()

    # Evaluate GNN on test set: full-graph forward pass, then keep test nodes only.
    # Flatten so predictions and labels compare element-wise even if the model
    # returns an (N, 1) output.
    gnn_model.eval()
    with torch.no_grad():
        out = gnn_model(data.x, data.edge_index)
        test_preds_gnn = torch.sigmoid(out[test_mask_t]).cpu().numpy().ravel()
        test_labels_gnn = data.y[test_mask_t].cpu().numpy().ravel()

    test_pred_labels_gnn = (test_preds_gnn > 0.5).astype(int)

    # Same metric definitions as the MLP evaluation. Labels are assumed to be exactly
    # 0/1, and precision/recall/F1 fall back to 0 when their denominators are zero.
    pred_void_gnn = test_pred_labels_gnn == 1
    true_void_gnn = test_labels_gnn == 1
    accuracy_gnn = (pred_void_gnn == true_void_gnn).mean()
    tp_gnn = (pred_void_gnn & true_void_gnn).sum()
    fp_gnn = (pred_void_gnn & ~true_void_gnn).sum()
    fn_gnn = (~pred_void_gnn & true_void_gnn).sum()

    precision_gnn = tp_gnn / (tp_gnn + fp_gnn) if (tp_gnn + fp_gnn) > 0 else 0
    recall_gnn = tp_gnn / (tp_gnn + fn_gnn) if (tp_gnn + fn_gnn) > 0 else 0
    f1_gnn = 2 * precision_gnn * recall_gnn / (precision_gnn + recall_gnn) if (precision_gnn + recall_gnn) > 0 else 0

    print("\nGNN Test Results:")
    print(f"  Accuracy:  {accuracy_gnn:.4f}")
    print(f"  Precision: {precision_gnn:.4f}")
    print(f"  Recall:    {recall_gnn:.4f}")
    print(f"  F1 Score:  {f1_gnn:.4f}")

## 6. Compare Results

Compare the performance of MLP and GNN models.

In [None]:
def format_results_row(name, acc, prec, rec, f1_score):
    """Format one row of the comparison table, aligned with its header."""
    return f"{name:<15} {acc:>10.4f} {prec:>10.4f} {rec:>10.4f} {f1_score:>10.4f}"


rule = "=" * 60

print("\n" + rule)
print("Model Comparison")
print(rule)
print(f"{'Model':<15} {'Accuracy':>10} {'Precision':>10} {'Recall':>10} {'F1':>10}")
print("-" * 60)
print(format_results_row('MLP', accuracy, precision, recall, f1))

if GNN_AVAILABLE:
    print(format_results_row('GNN (GCN)', accuracy_gnn, precision_gnn, recall_gnn, f1_gnn))

# Takeaways report the measured scores instead of asserting which model is better,
# since the outcome depends on the data that was actually loaded.
print("\n" + rule)
print("Key Takeaways:")
print(rule)
print("1. MLP uses only topological features (volume, neighbor count)")
print("   - Avoids spatial data leakage")
print(f"   - Test F1: {f1:.4f}")
print("")
if GNN_AVAILABLE:
    print("2. GNN leverages graph structure + node features")
    print("   - Can capture multi-scale patterns")
    print("   - Uses positions within graph context (not leakage-free like the MLP)")
    print(f"   - Test F1: {f1_gnn:.4f} ({f1_gnn - f1:+.4f} vs MLP)")
print(rule)

## Next Steps

1. **Test spatial generalization**: Split data spatially (e.g., by box regions) to verify no leakage
2. **Try other GNN models**: VoronoiGAT (attention-based) or VoronoiSAGE
3. **Feature engineering**: Add more Voronoi features (cell shape, anisotropy, etc.)
4. **Hyperparameter tuning**: Optimize model architecture and training parameters
5. **Apply to real data**: Test on observational catalogs

For more details, see `VORONOI_GNN_GUIDE.md` and `examples/compare_mlp_gnn_voronoi.py`.