# SPATIOTEMPORAL CLUSTERING OF VOXELS

Find clusters of voxels that are correlated in both space and time using:
1. Filter voxels within the mask with sufficient signal
2. Apply agglomerative clustering with spatial connectivity constraints
3. Visualize cluster locations and their mean intensity over time

In [None]:
from sklearn.cluster import AgglomerativeClustering
from sklearn.preprocessing import StandardScaler
from scipy.ndimage import gaussian_filter1d
from scipy.spatial.distance import pdist, squareform
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

# Parameters for clustering
N_CLUSTERS = 8  # number of clusters to find
MIN_SIGNAL_THRESHOLD = 0.1  # minimum mean signal to include voxel
TEMPORAL_SMOOTHING_SIGMA = 2  # gaussian smoothing for time series

# Get dimensions
n_frames, n_z, n_y, n_x = g5_masked.shape

# Create coordinate grids for spatial information
z_coords, y_coords, x_coords = np.meshgrid(
    np.arange(n_z), 
    np.arange(n_y), 
    np.arange(n_x), 
    indexing='ij'
)

# Flatten everything
voxel_timeseries = g5_masked.reshape(n_frames, -1).T  # shape: (n_voxels, n_frames)
z_flat = z_coords.flatten()
y_flat = y_coords.flatten()
x_flat = x_coords.flatten()

# Filter voxels: must be in mask and have sufficient signal
mask_3d = np.tile(mask_binned[np.newaxis, :, :], (n_z, 1, 1))
mask_flat = mask_3d.flatten()
mean_signal = voxel_timeseries.mean(axis=1)

valid_voxels = (mask_flat > 0) & (mean_signal > MIN_SIGNAL_THRESHOLD)
print(f"Total voxels: {len(valid_voxels)}")
print(f"Valid voxels (in mask with signal): {valid_voxels.sum()}")

Total voxels: 243750
Valid voxels (in mask with signal): 100677


In [None]:
# Extract valid voxel data
valid_timeseries = voxel_timeseries[valid_voxels]  # shape: (n_valid, n_frames)
valid_z = z_flat[valid_voxels]
valid_y = y_flat[valid_voxels]
valid_x = x_flat[valid_voxels]

# Smooth the time series
valid_timeseries_smooth = gaussian_filter1d(valid_timeseries, sigma=TEMPORAL_SMOOTHING_SIGMA, axis=1)

# Z-score normalize each voxel's time series for correlation-based clustering
scaler = StandardScaler()
valid_timeseries_norm = scaler.fit_transform(valid_timeseries_smooth.T).T

# Combine temporal features with spatial coordinates for clustering
# Weight spatial coordinates to encourage spatial contiguity
SPATIAL_WEIGHT = 0.5  # adjust to balance temporal vs spatial similarity

# Normalize spatial coordinates
spatial_coords = np.column_stack([valid_z, valid_y, valid_x])
spatial_coords_norm = (spatial_coords - spatial_coords.mean(axis=0)) / (spatial_coords.std(axis=0) + 1e-6)

# Create combined feature matrix: temporal + weighted spatial
combined_features = np.hstack([
    valid_timeseries_norm, 
    spatial_coords_norm * SPATIAL_WEIGHT * np.sqrt(n_frames)  # scale spatial to match temporal dimensions
])

print(f"Feature matrix shape: {combined_features.shape}")
print(f"  - Temporal features: {valid_timeseries_norm.shape[1]}")
print(f"  - Spatial features: 3 (z, y, x)")

Feature matrix shape: (100677, 654)
  - Temporal features: 651
  - Spatial features: 3 (z, y, x)


In [None]:
# Perform agglomerative clustering
print(f"Clustering {valid_voxels.sum()} voxels into {N_CLUSTERS} clusters...")

clustering = AgglomerativeClustering(
    n_clusters=N_CLUSTERS,
    metric='euclidean',
    linkage='ward'
)

cluster_labels = clustering.fit_predict(combined_features)
print(f"Clustering complete!")

# Count voxels per cluster
unique, counts = np.unique(cluster_labels, return_counts=True)
for c, n in zip(unique, counts):
    print(f"  Cluster {c}: {n} voxels")

Clustering 100677 voxels into 8 clusters...
Clustering complete!
  Cluster 0: 13682 voxels
  Cluster 1: 17648 voxels
  Cluster 2: 15907 voxels
  Cluster 3: 17280 voxels
  Cluster 4: 12222 voxels
  Cluster 5: 11425 voxels
  Cluster 6: 6590 voxels
  Cluster 7: 5923 voxels


In [None]:
# Compute mean intensity over time for each cluster
cluster_means = np.zeros((n_frames, N_CLUSTERS))
cluster_std = np.zeros((n_frames, N_CLUSTERS))

for c in range(N_CLUSTERS):
    cluster_mask = cluster_labels == c
    cluster_timeseries = valid_timeseries[cluster_mask]  # original (non-normalized) values
    cluster_means[:, c] = cluster_timeseries.mean(axis=0)
    cluster_std[:, c] = cluster_timeseries.std(axis=0)

# Normalize to F/F0 (baseline is first 60 frames)
baseline = cluster_means[:60, :].mean(axis=0, keepdims=True)
cluster_means_norm = cluster_means / (baseline + 1e-6)
cluster_std_norm = cluster_std / (baseline + 1e-6)

print("Cluster mean time series computed!")

Cluster mean time series computed!


### Visualize Cluster Intensity Over Time

In [None]:
cluster_means_norm

array([[6.68364736e-04, 9.59661277e-01, 2.08397723e-04, ...,
        9.89731024e-01, 9.00330419e-01, 1.03552052e+00],
       [6.85703443e-04, 9.55019647e-01, 2.23052491e-04, ...,
        9.85920526e-01, 8.04739513e-01, 1.02458439e+00],
       [6.98554330e-04, 1.01011356e+00, 5.64026809e-02, ...,
        9.90533424e-01, 8.36017073e-01, 1.03250555e+00],
       ...,
       [6.84402838e-03, 1.06273594e+00, 4.65215844e-04, ...,
        1.08438351e+00, 1.05852346e+00, 1.19439862e+00],
       [1.08269779e-03, 1.08775175e+00, 3.68597005e-04, ...,
        1.08162569e+00, 1.02092260e+00, 1.19262235e+00],
       [9.72647941e-01, 1.08214655e+00, 1.39519086e+00, ...,
        1.07703696e+00, 1.00121147e+00, 1.13448674e+00]], shape=(651, 8))

In [None]:
np.arange(N_CLUSTERS + 1).shape

(9,)

In [None]:
# Plot mean intensity of each cluster over time
%matplotlib qt

# Create time vector (assuming same t as defined earlier, or create new one)
try:
    time_vec = t
except NameError:
    time_vec = np.arange(n_frames)

# Use a colormap for clusters
cmap = plt.cm.get_cmap('tab10', N_CLUSTERS)

fig, axes = plt.subplots(2, 1, figsize=(14, 10), sharex=True)

# Top plot: All clusters overlaid
ax1 = axes[0]
for c in range(N_CLUSTERS):
    color = cmap(c)
    ax1.plot(time_vec, cluster_means_norm[:, c], label=f'Cluster {c}', color=color, lw=2, alpha=0.8)
ax1.axhline(1, ls='--', c='k', alpha=0.5, zorder=0)
ax1.set_ylabel(r'$F/F_{\mathrm{baseline}}$')
ax1.set_title('Mean Intensity of Each Cluster Over Time')
ax1.legend(loc='center left', bbox_to_anchor=(1, 0.5), fontsize=10)
ax1.set_xlim(time_vec[0], time_vec[-1])

# Bottom plot: Heatmap of cluster time series
ax2 = axes[1]
im = ax2.pcolormesh(time_vec, np.arange(N_CLUSTERS), cluster_means_norm.T, cmap='plasma', shading='auto')
ax2.set_ylabel('Cluster')
ax2.set_xlabel('Time')
ax2.set_yticks(np.arange(N_CLUSTERS) + 0.5)
ax2.set_yticklabels([f'{c}' for c in range(N_CLUSTERS)])
plt.colorbar(im, ax=ax2, label=r'$F/F_{\mathrm{baseline}}$')
ax2.set_title('Cluster Intensity Heatmap')

plt.tight_layout()
plt.show()

  cmap = plt.cm.get_cmap('tab10', N_CLUSTERS)


### Visualize Cluster Spatial Locations

In [None]:
# Create a 3D cluster label volume for visualization
cluster_volume = np.full((n_z, n_y, n_x), np.nan)

# Assign cluster labels to valid voxel positions
valid_indices = np.where(valid_voxels)[0]
for i, (z, y, x, label) in enumerate(zip(valid_z, valid_y, valid_x, cluster_labels)):
    cluster_volume[z, y, x] = label

# Plot 2D slices showing cluster locations at different z-levels
n_z_plots = min(6, n_z)
z_indices = np.linspace(0, n_z - 1, n_z_plots, dtype=int)

fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.flatten()

for i, z_idx in enumerate(z_indices):
    ax = axes[i]
    slice_data = cluster_volume[z_idx, :, :]
    
    # Create masked array for proper visualization
    masked_slice = np.ma.masked_invalid(slice_data)
    
    im = ax.imshow(masked_slice, cmap='tab10', vmin=0, vmax=N_CLUSTERS-1, interpolation='nearest')
    ax.set_title(f'Z-slice {z_idx}')
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.axis('equal')

# Add colorbar
cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])
cbar = fig.colorbar(im, cax=cbar_ax, ticks=np.arange(N_CLUSTERS))
cbar.set_label('Cluster')

plt.suptitle('Cluster Spatial Distribution Across Z-slices', fontsize=14, y=1.02)
plt.tight_layout(rect=[0, 0, 0.9, 1])
plt.show()

  plt.tight_layout(rect=[0, 0, 0.9, 1])


In [None]:
# 3D scatter plot of cluster locations
fig = plt.figure(figsize=(12, 10))
ax = fig.add_subplot(111, projection='3d')

# Subsample if too many points for visualization
max_points = 5000
if len(valid_z) > max_points:
    subsample_idx = np.random.choice(len(valid_z), max_points, replace=False)
else:
    subsample_idx = np.arange(len(valid_z))

scatter = ax.scatter(
    valid_x[subsample_idx], 
    valid_y[subsample_idx], 
    valid_z[subsample_idx],
    c=cluster_labels[subsample_idx], 
    cmap='tab10', 
    vmin=0, 
    vmax=N_CLUSTERS-1,
    s=10, 
    alpha=0.6
)

ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
ax.set_title('3D Cluster Distribution')

cbar = plt.colorbar(scatter, ax=ax, shrink=0.6, pad=0.1)
cbar.set_label('Cluster')
cbar.set_ticks(np.arange(N_CLUSTERS))

plt.tight_layout()
plt.show()

### Individual Cluster Plots with Spatial Location + Time Series

In [None]:
# Create individual plots for each cluster showing spatial location and time series
n_cols = 4
n_rows = int(np.ceil(N_CLUSTERS / n_cols))

fig, axes = plt.subplots(n_rows, n_cols, figsize=(16, 4 * n_rows))
axes = axes.flatten()

# Get MIP of one frame for background
background = g5_masked[0].max(axis=0)  # MIP across z

for c in range(N_CLUSTERS):
    ax = axes[c]
    color = cmap(c)
    
    # Create binary mask for this cluster (MIP across z)
    cluster_mask_3d = cluster_volume == c
    cluster_mask_mip = np.nanmax(cluster_mask_3d.astype(float), axis=0)
    cluster_mask_mip[cluster_mask_mip == 0] = np.nan
    
    # Plot background with cluster overlay
    ax.imshow(background, cmap='gray', alpha=0.7)
    ax.imshow(cluster_mask_mip, cmap='tab10', vmin=0, vmax=N_CLUSTERS-1, alpha=0.6)
    ax.set_title(f'Cluster {c} ({(cluster_labels == c).sum()} voxels)', color=color, fontweight='bold')
    ax.axis('off')

# Hide unused subplots
for i in range(N_CLUSTERS, len(axes)):
    axes[i].axis('off')

plt.suptitle('Cluster Spatial Locations (MIP)', fontsize=14)
plt.tight_layout()
plt.show()

# Plot individual time series for each cluster
fig, axes = plt.subplots(n_rows, n_cols, figsize=(16, 3 * n_rows), sharex=True, sharey=True)
axes = axes.flatten()

for c in range(N_CLUSTERS):
    ax = axes[c]
    color = cmap(c)
    
    # Plot mean with shaded std
    mean_trace = cluster_means_norm[:, c]
    std_trace = cluster_std_norm[:, c]
    
    ax.fill_between(time_vec, mean_trace - std_trace, mean_trace + std_trace, 
                    color=color, alpha=0.3)
    ax.plot(time_vec, mean_trace, color=color, lw=2)
    ax.axhline(1, ls='--', c='k', alpha=0.3)
    ax.set_title(f'Cluster {c}', color=color, fontweight='bold')
    ax.set_xlim(time_vec[0], time_vec[-1])

# Add common labels
for ax in axes[-n_cols:]:
    ax.set_xlabel('Time')
for ax in axes[::n_cols]:
    ax.set_ylabel(r'$F/F_{\mathrm{baseline}}$')

# Hide unused subplots
for i in range(N_CLUSTERS, len(axes)):
    axes[i].axis('off')

plt.suptitle('Cluster Mean Intensity Over Time (Â±1 std)', fontsize=14)
plt.tight_layout()
plt.show()

### Cluster Correlation Matrix

In [None]:
# Compute and plot correlation matrix between clusters
cluster_corr = np.corrcoef(cluster_means_norm.T)

fig, ax = plt.subplots(figsize=(8, 7))
im = ax.imshow(cluster_corr, cmap='RdBu_r', vmin=-1, vmax=1)
ax.set_xticks(np.arange(N_CLUSTERS))
ax.set_yticks(np.arange(N_CLUSTERS))
ax.set_xticklabels([f'{c}' for c in range(N_CLUSTERS)])
ax.set_yticklabels([f'{c}' for c in range(N_CLUSTERS)])
ax.set_xlabel('Cluster')
ax.set_ylabel('Cluster')
ax.set_title('Cluster-Cluster Correlation Matrix')

# Add correlation values as text
for i in range(N_CLUSTERS):
    for j in range(N_CLUSTERS):
        text_color = 'white' if abs(cluster_corr[i, j]) > 0.5 else 'black'
        ax.text(j, i, f'{cluster_corr[i, j]:.2f}', ha='center', va='center', 
                color=text_color, fontsize=9)

plt.colorbar(im, ax=ax, label='Correlation')
plt.tight_layout()
plt.show()

### Save Cluster Results

In [None]:
PTH

'D:\\DATA\\g5ht-free\\20251223\\date-20251223_strain-ISg5HT_condition-starvedpatch_worm005'

In [None]:
# Save clustering results
import tifffile

# Save cluster volume as tif
cluster_volume_save = np.nan_to_num(cluster_volume, nan=-1).astype(np.int16)
tifffile.imwrite(os.path.join(PTH, 'cluster_labels_volume.tif'), cluster_volume_save)

# Save cluster mean time series
cluster_results = {
    'cluster_means': cluster_means,
    'cluster_means_normalized': cluster_means_norm,
    'cluster_std': cluster_std,
    'n_clusters': N_CLUSTERS,
    'valid_voxel_coords': np.column_stack([valid_z, valid_y, valid_x]),
    'cluster_labels': cluster_labels,
    'time_vector': time_vec,
    'cluster_correlation': cluster_corr
}
np.save(os.path.join(PTH, 'cluster_results.npy'), cluster_results)

# Save cluster means as CSV for easy loading
import pandas as pd
df_clusters = pd.DataFrame(cluster_means_norm, columns=[f'cluster_{c}' for c in range(N_CLUSTERS)])
df_clusters.insert(0, 'time', time_vec)
df_clusters.to_csv(os.path.join(PTH, 'cluster_timeseries.csv'), index=False)

print(f"Saved clustering results to {PTH}")
print(f"  - cluster_labels_volume.tif: 3D volume with cluster labels")
print(f"  - cluster_results.npy: Full results dictionary")
print(f"  - cluster_timeseries.csv: Cluster mean time series")

Saved clustering results to D:\DATA\g5ht-free\20251223\date-20251223_strain-ISg5HT_condition-starvedpatch_worm005
  - cluster_labels_volume.tif: 3D volume with cluster labels
  - cluster_results.npy: Full results dictionary
  - cluster_timeseries.csv: Cluster mean time series
