# 3D Volume Visualization

Interactive Plotly views of 3D volume data (e.g., a sensitivity matrix or a reconstructed image) loaded from a `.pth`, `.npy`, or `.npz` file.

In [1]:
from pathlib import Path

import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import torch
from plotly.subplots import make_subplots

# Load 3D volume data - supports .pth, .npy, and .npz files.
# Result: `data_np`, the NumPy array every later cell works on.
# file_path = "config.iter.1.npy"  # Change this to your file
file_path = "sensitivity/sens_geoparallel.pth"

# Try loading as PyTorch first (works for .pth and for PyTorch files saved
# under a NumPy extension); fall back to NumPy if torch.load rejects the file.
try:
    # weights_only=True refuses arbitrary pickled objects
    data = torch.load(file_path, weights_only=True)
    if isinstance(data, dict) and data:
        # e.g. a saved state_dict: mirror the .npz handling and use the first entry
        keys = list(data.keys())
        print(f"Loaded dict with entries: {keys}")
        data = data[keys[0]]
        print(f"Using entry: {keys[0]}")
    # detach() so tensors saved with requires_grad=True can still be converted
    data_np = data.detach().cpu().numpy() if torch.is_tensor(data) else np.array(data)
    print("Loaded as PyTorch file")
except Exception as torch_err:
    try:
        # allow_pickle keeps its default (False), so object arrays are rejected
        loaded = np.load(file_path)
        if isinstance(loaded, np.lib.npyio.NpzFile):
            # NpzFile keeps the archive open; the `with` closes it once the
            # first array has been read into memory
            with loaded:
                keys = list(loaded.keys())
                print(f"Loaded .npz file with arrays: {keys}")
                data_np = loaded[keys[0]]
                print(f"Using array: {keys[0]}")
        else:
            data_np = loaded
        print("Loaded as NumPy file")
    except Exception as np_err:
        # Report both failures and keep the NumPy traceback chained
        raise ValueError(
            f"Could not load file as PyTorch ({torch_err}) or NumPy ({np_err})"
        ) from np_err

print(f"\nLoaded: {file_path}")
print(f"Volume shape: {data_np.shape}")
# nan-aware so a few NaN voxels do not turn the summary into nan
print(f"Min value: {np.nanmin(data_np):.2e}, Max value: {np.nanmax(data_np):.2e}")

Loaded as PyTorch file

Loaded: sensitivity/sens_geoparallel.pth
Volume shape: (1, 37, 37, 11)
Min value: 0.00e+00, Max value: 4.20e+00


## 1. Central Slice Views (XY, XZ, YZ planes)

In [2]:
# Central orthogonal slices through the volume, shown side by side with one
# shared colour scale whose upper limit can be lowered with a slider.

# Remove the batch dimension (first dimension), e.g. (1, 37, 37, 11) -> (37, 37, 11).
# The ndim guard makes re-running this cell harmless.
if data_np.ndim == 4:
    data_np = data_np[0]  # keeps only the first volume if the batch holds several

nx, ny, nz = data_np.shape
mid_x, mid_y, mid_z = nx // 2, ny // 2, nz // 2

print(f"Volume dimensions: nx={nx}, ny={ny}, nz={nz}")
print(f"Central indices: mid_x={mid_x}, mid_y={mid_y}, mid_z={mid_z}")

# For masked data: take the colour range from finite, non-zero voxels only.
# Zeros (masked-out voxels) would pin vmin at 0; NaNs would make both limits NaN.
mask = np.isfinite(data_np) & (data_np != 0)
if mask.any():
    vmin_global = data_np[mask].min()
    vmax_global = data_np[mask].max()
    print(f"Data range (excluding zeros): [{vmin_global:.3e}, {vmax_global:.3e}]")
else:
    vmin_global = np.nanmin(data_np)
    vmax_global = np.nanmax(data_np)
    print(f"Data range (including zeros): [{vmin_global:.3e}, {vmax_global:.3e}]")

slice_xy = data_np[:, :, mid_z]  # XY plane (looking down at Z)
slice_xz = data_np[:, mid_y, :]  # XZ plane (looking from side Y)
slice_yz = data_np[mid_x, :, :]  # YZ plane (looking from side X)

# (slice indexed [horizontal, vertical], trace name, x-axis label, y-axis label)
panels = [
    (slice_xy, 'XY', "X", "Y"),
    (slice_xz, 'XZ', "X", "Z"),
    (slice_yz, 'YZ', "Y", "Z"),
]

fig = make_subplots(
    rows=1, cols=len(panels),
    subplot_titles=('XY plane (Z central)', 'XZ plane (Y central)', 'YZ plane (X central)')
)

for col, (plane, name, xlabel, ylabel) in enumerate(panels, start=1):
    # Heatmap rows run along the vertical axis, hence the transpose
    fig.add_trace(
        go.Heatmap(
            z=plane.T, colorscale='Viridis', name=name,
            zmin=vmin_global, zmax=vmax_global,
            # all panels share zmin/zmax, so a single colorbar is enough
            showscale=(col == len(panels)),
        ),
        row=1, col=col,
    )
    # Equal pixel aspect: anchor each x axis to its own subplot's y axis ("y", "y2", "y3")
    fig.update_xaxes(title_text=xlabel, scaleanchor="y" if col == 1 else f"y{col}",
                     scaleratio=1, row=1, col=col)
    fig.update_yaxes(title_text=ylabel, row=1, col=col)

# Slider to control vmax (colorscale upper limit) as a percentage of [vmin, vmax]
vmax_percents = list(range(10, 101, 5))  # 10% to 100% in steps of 5%
sliders = [
    dict(
        # `active` is an index into `steps`: start on the last step (100% = full vmax)
        active=len(vmax_percents) - 1,
        yanchor="top",
        y=-0.15,
        xanchor="left",
        currentvalue=dict(
            prefix="Colorscale max: ",
            visible=True,
            xanchor="right",
            suffix=f" ({vmax_global:.2e} at 100%)"
        ),
        pad=dict(b=10, t=50),
        len=0.9,
        x=0.05,
        steps=[
            dict(
                method="restyle",
                # one zmax per heatmap trace
                args=[{"zmax": [vmin_global + (p / 100) * (vmax_global - vmin_global)] * len(panels)}],
                label=f"{p}%"
            ) for p in vmax_percents
        ]
    )
]

fig.update_layout(
    height=600,
    showlegend=False,
    title_text="Central Slices (Interactive Colorscale)",
    sliders=sliders
)
fig.show()

print("\nUse the slider to adjust the maximum colorscale value (vmax)")
print("- Lower percentages = compress colorscale to see more detail in low-intensity regions")
print("- Higher percentages = full dynamic range")


Volume dimensions: nx=37, ny=37, nz=11
Central indices: mid_x=18, mid_y=18, mid_z=5
Data range (excluding zeros): [6.800e-02, 4.203e+00]



Use the slider to adjust the maximum colorscale value (vmax)
- Lower percentages = compress colorscale to see more detail in low-intensity regions
- Higher percentages = full dynamic range


## 2. Interactive 3D Volume Visualization

In [12]:
# Create 3D volume plot using Volume trace (better for continuous data), with
# sliders for opacity and for a lower intensity threshold (isomin).
# nan-aware summary so NaN voxels do not poison the stats
print(f"Volume stats: min={np.nanmin(data_np):.2e}, max={np.nanmax(data_np):.2e}, mean={np.nanmean(data_np):.2e}")

# For masked data: use only finite, non-zero values for the rendering range
mask_vol = np.isfinite(data_np) & (data_np != 0)
if mask_vol.any():
    vmin_data = data_np[mask_vol].min()
    vmax_data = data_np[mask_vol].max()
    print(f"Volume range (non-zero values): [{vmin_data:.3e}, {vmax_data:.3e}]")
else:
    vmin_data = np.nanmin(data_np)
    vmax_data = np.nanmax(data_np)

# Voxel-index coordinate grids; mgrid and flatten() both use C order, so the
# flattened X/Y/Z line up element-by-element with data_np.flatten().
X, Y, Z = np.mgrid[0:nx, 0:ny, 0:nz]

# Opacity choices offered by the slider; the initial opacity must be one of them
# so the slider's starting position matches what is actually rendered.
opacity_init = 0.1
opacity_levels = [i / 100 for i in range(5, 81, 5)]  # 0.05 to 0.80 in steps of 0.05
opacity_active = min(range(len(opacity_levels)),
                     key=lambda k: abs(opacity_levels[k] - opacity_init))

# Threshold = vmin_data + p% of (vmax_data - vmin_data)
threshold_percents = list(range(0, 91, 10))  # 0% to 90% in steps of 10%

fig = go.Figure()

fig.add_trace(go.Volume(
    x=X.flatten(),
    y=Y.flatten(),
    z=Z.flatten(),
    value=data_np.flatten(),
    isomin=vmin_data,
    isomax=vmax_data,
    opacity=opacity_init,  # changed later by the opacity slider
    surface_count=21,
    colorscale='Viridis',
    showscale=True,
    colorbar=dict(
        title="Intensity",
        thickness=20,
        len=0.7,
        x=1.02
    ),
    caps=dict(x_show=False, y_show=False, z_show=False)
))

sliders = [
    # Opacity slider
    dict(
        active=opacity_active,  # step matching opacity_init
        yanchor="top",
        y=1.15,
        xanchor="left",
        currentvalue=dict(
            prefix="Opacity: ",
            visible=True,
            xanchor="right"
        ),
        pad=dict(t=50),
        len=0.4,
        x=0.05,
        steps=[
            dict(
                method="restyle",
                args=[{"opacity": [level]}],
                label=f"{level:.2f}"
            ) for level in opacity_levels
        ]
    ),
    # Threshold slider: raises isomin to hide low-intensity regions
    dict(
        active=0,  # 0% = show all
        yanchor="top",
        y=1.15,
        xanchor="left",
        currentvalue=dict(
            prefix="Threshold: ",
            visible=True,
            xanchor="right"
        ),
        pad=dict(t=50),
        len=0.4,
        x=0.55,
        steps=[
            dict(
                method="restyle",
                args=[{"isomin": [vmin_data + (p / 100) * (vmax_data - vmin_data)]}],
                label=f"{p}%"
            ) for p in threshold_percents
        ]
    )
]

fig.update_layout(
    sliders=sliders,
    title='3D Volume Rendering (Interactive)',
    scene=dict(
        xaxis_title='X voxels',
        yaxis_title='Y voxels',
        zaxis_title='Z voxels',
        aspectmode='data',
        camera=dict(
            eye=dict(x=1.5, y=1.5, z=1.5)
        )
    ),
    height=800,
    width=1000
)
fig.show()

print("\nInteractive controls:")
print("- Left slider: Adjust opacity (transparency)")
print("- Right slider: Adjust threshold (hide low-intensity regions)")
print("  * 0% = show everything")
print("  * 50% = only show values above the midpoint of the non-zero intensity range")
print("  * 90% = only show highest intensity regions")

Volume stats: min=0.00e+00, max=6.27e+00, mean=2.70e-01
Volume range (non-zero values): [6.825e-02, 6.266e+00]



Interactive controls:
- Left slider: Adjust opacity (transparency)
- Right slider: Adjust threshold (hide low-intensity regions)
  * 0% = show everything
  * 50% = only show regions above 50% of max intensity
  * 90% = only show highest intensity regions


## 3. Interactive Slice Explorer (axis and slice sliders)

In [7]:
# Slice explorer: one heatmap trace per slice along every axis. The bottom
# slider chooses the axis, the top slider chooses the slice; both work by
# toggling trace visibility (the axis selector also swaps the top slider's
# steps and the axis titles).
fig = go.Figure()

# Global min/max for a consistent colorscale. Finite, non-zero voxels only:
# zeros are masked-out voxels, and NaNs would otherwise make both limits NaN.
mask = np.isfinite(data_np) & (data_np != 0)
if mask.any():
    vmin = data_np[mask].min()
    vmax = data_np[mask].max()
    print(f"Fixed colorscale range (non-zero values): [{vmin:.3e}, {vmax:.3e}]")
else:
    vmin = np.nanmin(data_np)
    vmax = np.nanmax(data_np)
    print(f"Fixed colorscale range: [{vmin:.3e}, {vmax:.3e}]")

# (array axis to slice along, axis name, heatmap x label, heatmap y label, selector label)
slice_axes = [
    (2, "Z", "X", "Y", "Z (XY planes)"),
    (1, "Y", "X", "Z", "Y (XZ planes)"),
    (0, "X", "Y", "Z", "X (YZ planes)"),
]
n_traces = sum(data_np.shape)  # nx + ny + nz heatmaps in total


def _only_visible(trace_idx):
    """Return a visibility list over all traces with only `trace_idx` shown."""
    visible = [False] * n_traces
    visible[trace_idx] = True
    return visible


slice_steps = {}     # axis name -> steps of the slice-index slider
slice_mid = {}       # axis name -> middle slice index
selector_steps = []  # steps of the axis-selector slider
first_trace = 0      # index of the first trace belonging to the current axis
for array_axis, axis_name, xlabel, ylabel, selector_label in slice_axes:
    n_slices = data_np.shape[array_axis]
    mid = n_slices // 2
    for i in range(n_slices):
        fig.add_trace(go.Heatmap(
            # np.take(data_np, i, axis=2) == data_np[:, :, i], etc.; the transpose
            # puts the second in-plane axis on the vertical
            z=np.take(data_np, i, axis=array_axis).T,
            colorscale='Viridis',
            visible=(axis_name == "Z" and i == mid),  # middle Z slice visible initially
            name=f'{axis_name}={i}',
            zmin=vmin,
            zmax=vmax,
            colorbar=dict(
                title="Intensity",
                thickness=20,
                len=0.7
            )
        ))

    steps = [
        dict(
            method="update",
            args=[{"visible": _only_visible(first_trace + i)}],
            label=str(i)
        )
        for i in range(n_slices)
    ]
    slice_steps[axis_name] = steps
    slice_mid[axis_name] = mid

    # Switching axis shows its middle slice, reloads the slice slider with this
    # axis's steps and relabels the heatmap axes.
    selector_steps.append(dict(
        method="update",
        args=[
            {"visible": _only_visible(first_trace + mid)},
            {
                "sliders[0].active": mid,
                "sliders[0].steps": steps,
                "xaxis.title.text": xlabel,
                "yaxis.title.text": ylabel,
            }
        ],
        label=selector_label,
    ))
    first_trace += n_slices

sliders = [
    # Slice position slider (TOP - index 0; referenced as sliders[0] above)
    dict(
        active=slice_mid["Z"],
        yanchor="top",
        y=0.05,
        xanchor="left",
        currentvalue=dict(
            prefix="Slice index: ",
            visible=True,
            xanchor="right"
        ),
        pad=dict(b=10, t=50),
        len=0.9,
        x=0.1,
        steps=slice_steps["Z"]
    ),
    # Axis selector slider (BOTTOM - index 1)
    dict(
        active=0,
        yanchor="top",
        y=-0.15,
        xanchor="left",
        currentvalue=dict(
            prefix="Slice axis: ",
            visible=True,
            xanchor="right"
        ),
        pad=dict(b=10, t=50),
        len=0.9,
        x=0.1,
        steps=selector_steps
    )
]

fig.update_layout(
    sliders=sliders,
    title="3D Volume Slice Explorer - Choose Axis and Slice",
    height=700,
    # initial view is a Z slice (XY plane); the axis selector relabels these
    xaxis=dict(title_text="X", scaleanchor="y", scaleratio=1),
    yaxis=dict(title_text="Y"),
)

fig.show()

print("\nInstructions:")
print("- BOTTOM slider: Choose which axis to slice through (Z/Y/X)")
print("- TOP slider: Choose the slice position along that axis")
print("\nNote: When you change the axis (bottom slider), it will show the middle slice of that axis")

Fixed colorscale range: [nan, nan]



Instructions:
- BOTTOM slider: Choose which axis to slice through (Z/Y/X)
- TOP slider: Choose the slice position along that axis

Note: When you change the axis (bottom slider), it will show the middle slice of that axis
