# In [None]:
import numpy as np
import torch
import scipy.ndimage as ndi
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import random

# ===============================
# STEP 1. Generate a 3D sphere segmentation
# ===============================
def generate_sphere_segmentation(shape=(64, 64, 64), radius=20, center=None):
    """
    Generates a binary segmentation of a sphere (a ball on the voxel grid).

    The default 3-D ``shape`` yields a sphere; any other dimensionality works
    too (a 2-D shape yields a disk, etc.).

    Args:
      shape: tuple (H, W, D) defining the volume dimensions (any length >= 1).
      radius: radius of the sphere, in voxel units.
      center: center of the sphere, one coordinate per axis in voxel-index
        units. If None, defaults to ``np.array(shape) / 2`` (e.g. 32.0 on a
        64-voxel axis, whose index range is 0..63).

    Returns:
      seg: numpy float32 array of shape ``shape`` with values 1 (inside,
        Euclidean distance to ``center`` <= radius) and 0 (outside).
    """
    if center is None:
        center = np.array(shape) / 2
    center = np.asarray(center, dtype=np.float64)
    # np.indices stacks the per-axis index grids along axis 0; move that axis
    # last so every voxel holds its coordinate vector: shape (*shape, ndim).
    grid = np.moveaxis(np.indices(shape), 0, -1).astype(np.float32)
    # Euclidean distance from the center at every voxel.
    dists = np.linalg.norm(grid - center, axis=-1)
    return (dists <= radius).astype(np.float32)

# ===============================
# STEP 2. Compute the SDF from the segmentation
# ===============================
def compute_target_sdf(seg_mask, clip_threshold=10, *, verbose=True):
    """
    Compute a signed distance field (SDF) from a binary segmentation mask.

    For each volume the SDF is computed as
         sdf(x) = distance(x, foreground) - distance(x, background)
    with ``scipy.ndimage.distance_transform_edt`` (distances in voxel units).
    Each term is zero on its own side of the boundary. Voxels inside the mask
    therefore get minus their distance to the nearest background voxel
    (negative), and voxels outside get their distance to the nearest foreground
    voxel (positive). Distances are measured between voxel centres, so no voxel
    has |sdf| < 1: boundary voxels sit at exactly -1 / +1.

    Args:
      seg_mask: torch.Tensor of shape [B, 1, H, W, D]; values > 0 are treated
        as foreground.
      clip_threshold: maximum absolute value for SDF clipping.
      verbose: if True (default), print per-volume counts of negative and
        non-negative voxels.

    Returns:
      torch.Tensor of shape [B, 1, H, W, D], dtype float32, on the same device
      as ``seg_mask``, with SDF values clipped to [-clip_threshold, clip_threshold].
    """
    # detach() so masks that are part of an autograd graph can be converted.
    seg_mask_np = seg_mask.detach().cpu().numpy()  # shape: [B, 1, H, W, D]
    bin_mask = seg_mask_np > 0
    sdf_np = np.zeros_like(bin_mask, dtype=np.float32)
    for b in range(bin_mask.shape[0]):
        vol = bin_mask[b, 0]
        if not vol.any():
            # No foreground, so there is no surface to measure distance to:
            # every voxel is "infinitely" outside, saturate at +clip_threshold.
            sdf = np.full(vol.shape, float(clip_threshold))
        elif vol.all():
            # No background: every voxel is "infinitely" inside.
            sdf = np.full(vol.shape, -float(clip_threshold))
        else:
            dt_out = ndi.distance_transform_edt(~vol)  # outside voxels: dist to foreground
            dt_in = ndi.distance_transform_edt(vol)    # inside voxels: dist to background
            sdf = dt_out - dt_in
        sdf_np[b, 0] = np.clip(sdf, -clip_threshold, clip_threshold)
        if verbose:
            neg_count = np.sum(sdf < 0)
            pos_count = np.sum(sdf >= 0)
            total = sdf.size
            print(f"Shape {b}: Negative voxels: {neg_count} ({neg_count/total:.2%}), Positive voxels: {pos_count} ({pos_count/total:.2%})")
    return torch.from_numpy(sdf_np).to(seg_mask.device)

# ===============================
# STEP 3. Sample balanced narrow band points
# ===============================
def sample_sdf_balanced(seg_mask, sdf, num_points=25000, band_threshold=4, *, seed=42):
    """
    Samples points exclusively from a narrow band where |SDF| < band_threshold,
    and returns a (near-)balanced set of negative and non-negative samples.

    ``num_points // 2`` negative samples and ``num_points - num_points // 2``
    non-negative samples are drawn, so exactly ``num_points`` points are
    returned. An odd count gives the extra point to the non-negative side.
    Each class is sampled without replacement when it has enough voxels, and
    with replacement otherwise.

    Args:
      seg_mask: torch.Tensor of shape [1, 1, H, W, D]. Not used by the
        sampling; kept for API compatibility.
      sdf: torch.Tensor of shape [1, 1, H, W, D] (more generally, any shape
        whose trailing three dims are H, W, D and whose leading dims are 1).
      num_points: total number of points to sample.
      band_threshold: consider only voxels with |SDF| < band_threshold.
      seed: seed for a private ``np.random.RandomState``. It produces the same
        draws as seeding the global NumPy RNG would, but leaves the global RNG
        untouched. Pass None for non-deterministic sampling.

    Returns:
      coords: integer numpy array of shape [num_points, 3] with voxel indices.
      sdf_samples: numpy array of shape [num_points] with the SDF values at
        those coordinates (negative samples first, then non-negative).

    Raises:
      ValueError: if the band is empty, contains no negative voxels or no
        non-negative voxels, or ``sdf`` has non-singleton leading dims.
    """
    # Keep exactly the trailing spatial dims. Unlike squeeze(), this never
    # drops a spatial axis of size 1, and it raises instead of silently
    # mis-indexing when the batch/channel dims are larger than 1.
    sdf_np = sdf.detach().cpu().numpy()
    sdf_np = sdf_np.reshape(sdf_np.shape[-3:])

    # Create a narrow band mask
    band_mask = np.abs(sdf_np) < band_threshold
    band_indices = np.argwhere(band_mask)
    if len(band_indices) == 0:
        raise ValueError("No voxels found in the narrow band. Increase band_threshold.")

    # Boolean-mask selection uses the same C order as argwhere, so band_sdf[i]
    # is the SDF value at band_indices[i].
    band_sdf = sdf_np[band_mask]
    neg_indices = band_indices[band_sdf < 0]    # inside the object
    pos_indices = band_indices[band_sdf >= 0]   # outside the object

    if len(neg_indices) == 0 or len(pos_indices) == 0:
        raise ValueError("Insufficient negative or positive points in the narrow band.")

    n_neg = num_points // 2
    n_pos = num_points - n_neg

    rng = np.random.RandomState(seed)
    selected_neg = neg_indices[rng.choice(len(neg_indices), size=n_neg, replace=(len(neg_indices) < n_neg))]
    selected_pos = pos_indices[rng.choice(len(pos_indices), size=n_pos, replace=(len(pos_indices) < n_pos))]
    selected_coords = np.concatenate([selected_neg, selected_pos], axis=0)
    sampled_sdf = sdf_np[tuple(selected_coords.T)]

    return selected_coords, sampled_sdf

# ===============================
# STEP 4. 3D interactive plotting functions (for Jupyter notebooks)
# ===============================
def plot_full_sdf_scatter(coords, sdf_values):
    """Show an interactive 3D scatter of every sample, coloured by its SDF.

    Args:
      coords: array of shape [N, 3]; columns 0, 1, 2 are plotted as X, Y, Z.
      sdf_values: array of shape [N]; mapped through the 'coolwarm' colormap.
    """
    figure = plt.figure(figsize=(10, 8))
    axes = figure.add_subplot(111, projection='3d')
    xs, ys, zs = coords[:, 0], coords[:, 1], coords[:, 2]
    points = axes.scatter(
        xs, ys, zs,
        c=sdf_values,
        cmap='coolwarm',
        marker='o',
        s=5,
        alpha=0.8,
    )
    plt.colorbar(points, ax=axes, label='SDF Value')
    axes.set_title("3D Scatter Plot of All SDF Samples")
    for set_label, text in ((axes.set_xlabel, "X"), (axes.set_ylabel, "Y"), (axes.set_zlabel, "Z")):
        set_label(text)
    plt.show()

def plot_surface_scatter(coords, sdf_values, threshold=1.0):
    """Show an interactive 3D scatter of only the near-surface samples.

    A sample is kept when |SDF| is strictly below ``threshold``. If no sample
    qualifies, a message is printed and no figure is created.

    Args:
      coords: array of shape [N, 3]; columns 0, 1, 2 are plotted as X, Y, Z.
      sdf_values: array of shape [N] used both for selection and colouring.
      threshold: strict upper bound on |SDF| for a sample to be plotted.
    """
    near = np.abs(sdf_values) < threshold
    picked_coords = coords[near]
    picked_sdf = sdf_values[near]

    if not picked_coords.size:
        print("No points found near the surface with the given threshold.")
        return

    figure = plt.figure(figsize=(10, 8))
    axes = figure.add_subplot(111, projection='3d')
    xs, ys, zs = picked_coords[:, 0], picked_coords[:, 1], picked_coords[:, 2]
    points = axes.scatter(xs, ys, zs, c=picked_sdf, cmap='coolwarm', marker='o', s=10)
    plt.colorbar(points, ax=axes, label=f'SDF Value (|SDF| < {threshold})')
    axes.set_title(f"3D Scatter Plot of Surface Samples (|SDF| < {threshold})")
    for set_label, text in ((axes.set_xlabel, "X"), (axes.set_ylabel, "Y"), (axes.set_zlabel, "Z")):
        set_label(text)
    plt.show()

# ===============================
# Main workflow for a synthetic sphere
# ===============================
def main():
    """Run the synthetic-sphere demo end to end.

    Builds a sphere mask, converts it to a clipped SDF, draws a balanced set
    of narrow-band samples, and shows two 3D scatter plots of those samples.
    """
    # (a) Generate a sphere segmentation. Adjust parameters as needed.
    shape = (64, 64, 64)
    radius = 20
    sphere_seg = generate_sphere_segmentation(shape=shape, radius=radius)
    print("Generated sphere segmentation with shape:", sphere_seg.shape)

    # (b) Convert segmentation mask to a torch tensor and add batch and channel dimensions.
    seg_tensor = torch.from_numpy(sphere_seg).unsqueeze(0).unsqueeze(0)  # shape: [1, 1, H, W, D]

    # (c) Compute the SDF.
    sdf_tensor = compute_target_sdf(seg_tensor, clip_threshold=10)
    print("Computed SDF tensor with shape:", sdf_tensor.shape)

    # (d) Sample balanced narrow band points.
    num_points = 10000  # change if desired
    band_threshold = 4  # consider points with |SDF| < 4
    coords, sdf_samples = sample_sdf_balanced(seg_tensor, sdf_tensor, num_points=num_points, band_threshold=band_threshold)
    print("Sampled", coords.shape[0], "points from the narrow band.")

    # (e) Create interactive scatter plots.
    # Full scatter plot.
    plot_full_sdf_scatter(coords, sdf_samples)
    # Surface scatter plot (only points nearest the zero level set).
    # The SDF comes from a Euclidean distance transform between voxel centres,
    # so the smallest possible |SDF| is exactly 1 (face-adjacent boundary
    # voxels). A strict "< 1.0" cut would therefore select nothing. 1.5 keeps
    # the +/-1 and +/-sqrt(2) shells on either side of the surface.
    surface_threshold = 1.5
    plot_surface_scatter(coords, sdf_samples, threshold=surface_threshold)

# Run the demo only when executed as a script or notebook cell (both set
# __name__ to "__main__"), not when this module is imported.
if __name__ == "__main__":
    main()
