In [1]:
import numpy as np
# Load the training seismic volume and its label volume, in that order.
# Per the original notes, both arrays share one shape: (I, X, D).
train_seismic, train_labels = map(
    np.load,
    ('data/train/train_seismic.npy', 'data/train/train_labels.npy'),
)

In [2]:
import matplotlib.pyplot as plt
import plotly.graph_objs as go
import ipywidgets as widgets
from IPython.display import display

# Alias the training arrays for cropping. These are references to the same
# data, not copies, so no extra memory is used.
seismic_all = train_seismic
labels_all = train_labels

# Set up the per-class color table: row k holds the 0-255 RGB color of label k.
cmap = plt.get_cmap('tab20')
# Cast to a Python int before adding 1. Arithmetic on a NumPy scalar keeps the
# array's dtype, which can wrap around for narrow integer label dtypes, and
# np.linspace needs an integer sample count even if labels are stored as floats.
n_classes = int(labels_all.max()) + 1
colors_lookup = (cmap(np.linspace(0, 1, n_classes))[:, :3] * 255).astype(np.uint8)

# Volume dimensions (extent of axes 0, 1 and 2 of the seismic volume)
X_MAX, Y_MAX, Z_MAX = seismic_all.shape

# Interactive plot function
def plot_cropped(threshold, x_range, y_range, z_range, sample_every):
    """Show an interactive 3D point cloud of above-threshold voxels in a crop.

    Crops ``seismic_all`` and ``labels_all`` to the requested index ranges,
    keeps the voxels whose seismic value is strictly greater than
    ``threshold``, optionally decimates them, and displays them as a Plotly
    ``Scatter3d`` trace colored per label through ``colors_lookup``.

    Args:
        threshold: Voxels with seismic value > ``threshold`` are plotted.
        x_range: ``(start, end)`` slice bounds along axis 0 (end exclusive).
        y_range: ``(start, end)`` slice bounds along axis 1 (end exclusive).
        z_range: ``(start, end)`` slice bounds along axis 2 (end exclusive).
        sample_every: Keep every N-th selected point; values <= 1 keep all.

    Returns:
        None. The figure is displayed via ``fig.show()``. If no voxel in the
        crop exceeds the threshold, a message is printed and nothing is drawn.
    """
    x_start, x_end = x_range
    y_start, y_end = y_range
    z_start, z_end = z_range

    # Crop subvolume
    seismic = seismic_all[x_start:x_end, y_start:y_end, z_start:z_end]
    labels = labels_all[x_start:x_end, y_start:y_end, z_start:z_end]

    # Threshold mask
    mask = seismic > threshold
    points = np.argwhere(mask)

    if len(points) == 0:
        print("No points exceed threshold in selected region.")
        return

    # Boolean indexing and np.argwhere both traverse the mask in row-major
    # order, so labs[i] is the label of the voxel at points[i].
    labs = labels[mask]

    # === Sampling ===
    if sample_every > 1:
        # Apply the same stride to both arrays to keep them aligned.
        points = points[::sample_every]
        labs = labs[::sample_every]

    # Convert labels to RGB color strings. Format one string per class and
    # gather per point with NumPy indexing, instead of formatting a string for
    # every point in a Python loop. The cast lets float-stored labels index
    # the palette as well.
    palette = np.array(['rgb({}, {}, {})'.format(*rgb) for rgb in colors_lookup])
    colors = palette[labs.astype(np.intp)].tolist()

    # Restore to global coordinates by adding back the crop offsets
    scatter = go.Scatter3d(
        x=points[:, 0] + x_start,
        y=points[:, 1] + y_start,
        z=points[:, 2] + z_start,
        mode='markers',
        marker=dict(size=2, color=colors, opacity=0.8)
    )

    layout = go.Layout(
        title='Seismic Point Cloud Visualization',
        scene=dict(
            xaxis_title='Inline',
            yaxis_title='Crossline',
            zaxis_title='Depth',
            # Reverse the z axis so larger depth indices are drawn lower.
            zaxis=dict(autorange='reversed')
            ),
        margin=dict(l=0, r=0, b=0, t=50)
    )

    fig = go.Figure(data=[scatter], layout=layout)
    fig.show()

# === Widgets ===
# continuous_update=False makes each slider report its value only when the
# handle is released, so the 3D figure is rebuilt once per adjustment rather
# than on every intermediate step of a drag.
threshold_slider = widgets.FloatSlider(
    value=0.3, min=0.0, max=1.0, step=0.05,
    description='Threshold:', continuous_update=False,
)
# Default crop windows are clamped to the volume extents so the initial
# selection stays valid for volumes smaller than the hard-coded defaults.
x_slider = widgets.IntRangeSlider(
    value=[0, min(150, X_MAX)], min=0, max=X_MAX, step=1,
    description='Inline:', continuous_update=False,
)
y_slider = widgets.IntRangeSlider(
    value=[min(450, Y_MAX), min(650, Y_MAX)], min=0, max=Y_MAX, step=1,
    description='Crossline:', continuous_update=False,
)
z_slider = widgets.IntRangeSlider(
    value=[0, Z_MAX], min=0, max=Z_MAX, step=1,
    description='Depth:', continuous_update=False,
)
sample_slider = widgets.IntSlider(
    value=5, min=1, max=20, step=1,
    description='Sample Rate:', continuous_update=False,
)

# === Combine UI ===
ui = widgets.VBox([threshold_slider, x_slider, y_slider, z_slider, sample_slider])
# Bind each slider to the matching keyword argument of plot_cropped; the plot
# is rendered into the `out` widget whenever a bound value changes.
out = widgets.interactive_output(
    plot_cropped,
    {
        'threshold': threshold_slider,
        'x_range': x_slider,
        'y_range': y_slider,
        'z_range': z_slider,
        'sample_every': sample_slider
    }
)

# Display
display(ui, out)

VBox(children=(FloatSlider(value=0.3, description='Threshold:', max=1.0, step=0.05), IntRangeSlider(value=(0, …

Output()