In [1]:
import h5py
import numpy as np
import matplotlib.pyplot as plt
from scipy import signal
from scipy.signal import butter, filtfilt, decimate
from scipy.fft import fft, fft2, fftshift
import warnings
import time

warnings.filterwarnings('ignore')

%matplotlib tk

In [2]:
'''
Parameter Setup
'''
vmines, vmaxes = -1, 1
vminkf, vmaxfk = 20, 80 # Adjusting the colormap scale

In [3]:
'''
============================================================================
Read Dataset
============================================================================
'''
def read_h5_file(filePath):
    """Reads HDF5 file and extracts metadata and data."""
    print('Reading file metadata...')

    try:
        with h5py.File(filePath, 'r') as f:
            # Read metadata
            fs = f['/Acquisition'].attrs['PulseRate']
            dx = f['/Acquisition'].attrs['SpatialSamplingInterval']
            num_locs = f['/Acquisition'].attrs['NumberOfLoci']

            # Read main data
            X = f['/Acquisition/Raw[0]/RawData'][:]
            X = X.astype(np.float64)

        print('Metadata read successfully:')
        print(f' -> fs = {fs:.1f} Hz')
        print(f' -> dx = {dx:.3f} m')
        print(f' -> num_locs = {num_locs} channels')

        return X, fs, dx, num_locs

    except Exception as e:
        raise RuntimeError(f'Error reading the file: {str(e)}')

In [4]:
'''
============================================================================
TEMPORAL TRIMMING
============================================================================
'''
def temporal_cut(X, t, t_start_cut=5.0, t_end_cut=35.0):
    """Applies temporal trimming to the data."""
    print('Applying temporal cut...')

    t_original_end = t[-1]
    num_tr_original = X.shape[1]

    # Find indices
    idx_start_mask = t >= t_start_cut
    idx_end_mask = t <= t_end_cut

    if not np.any(idx_start_mask):
        raise ValueError(f'No times >= {t_start_cut}s found. Max time: {t_original_end:.2f}s')
    if not np.any(idx_end_mask):
        raise ValueError(f'No times <= {t_end_cut}s found. Max time: {t_original_end:.2f}s')

    idx_start = np.where(idx_start_mask)[0][0]
    idx_end = np.where(idx_end_mask)[0][-1]

    if idx_end <= idx_start:
        raise ValueError(f'Invalid time interval: idx_end({idx_end}) <= idx_start({idx_start})')

    # Apply cut
    X_cut = X[:, idx_start:idx_end + 1]
    t_cut = t[idx_start:idx_end + 1]

    print('Temporal cut completed.')
    print(f' -> Original time: {t[0]:.2f} s to {t_original_end:.2f} s ({num_tr_original} samples)')
    print(f' -> New time: {t_cut[0]:.2f} s to {t_cut[-1]:.2f} s ({X_cut.shape[1]} samples)')

    return X_cut, t_cut


In [5]:
'''
============================================================================
BAND-PASS FILTER
============================================================================
'''

def bandpass_filter(X, fs, hp_cut=0.1, lp_cut=9000, order=4):
    """Applies a band-pass filter to each channel."""
    print(f'Applying band-pass filter from {hp_cut} to {lp_cut} Hz...')

    nyq = 0.5 * fs
    low = lp_cut / nyq
    high = hp_cut / nyq

    if high >= 1 or low <= 0:
        raise ValueError(f"Cutoff frequencies out of range. Nyquist: {nyq} Hz")

    if high > low:
        high, low = low, high

    b, a = butter(order, [high, low], btype='band')

    X_filtered = np.zeros_like(X)
    for k in range(X.shape[0]):
        X_filtered[k, :] = filtfilt(b, a, X[k, :])

    print('Filtering completed.')
    return X_filtered

In [6]:
'''
============================================================================
SPATIAL GROUPING (BINNING)
============================================================================
'''

def spatial_binning(X, y, N_group=3):
    """Spatially groups channels."""
    print('Starting spatial channel binning...')

    num_locs_real, num_tr = X.shape
    n_keep = (num_locs_real // N_group) * N_group

    if n_keep < N_group:
        raise ValueError(f'N_group={N_group} is too large for num_locs={num_locs_real}')

    if n_keep < num_locs_real:
        print(f'Discarding last {num_locs_real - n_keep} channels')

    X_trim = X[:n_keep, :]
    y_trim = y[:n_keep]

    n_groups = n_keep // N_group
    X_binned = X_trim.reshape(N_group, n_groups, num_tr).mean(axis=0)
    y_binned = y_trim.reshape(N_group, n_groups).mean(axis=0)

    print('Binning completed!')
    print(f' -> Binning factor: {N_group}')
    print(f' -> New dimensions: {X_binned.shape[0]} channels × {num_tr} samples')

    return X_binned, y_binned

In [7]:
'''
============================================================================
FOURIER AND CORRELATION ANALYSIS
============================================================================
'''

def fourier_analysis(signal_data, fs):
    """Computes FFT of a signal."""
    L = len(signal_data)
    f = np.fft.rfftfreq(L, 1 / fs)
    Y = fft(signal_data)
    P2 = np.abs(Y / L)
    P1 = P2[:L // 2 + 1]
    P1[1:-1] = 2 * P1[1:-1]
    return f, P1

def cross_correlation(signal1, signal2, fs):
    """Computes cross-correlation and time delay."""
    c = np.correlate(signal2, signal1, mode='full')
    lags = np.arange(-len(signal1) + 1, len(signal2))
    max_corr_idx = np.argmax(c)
    time_delay = lags[max_corr_idx] / fs
    return c, lags / fs, time_delay

In [8]:
'''
============================================================================
SPECTROGRAM WITH INTEGRATED BAND
============================================================================
'''

def integrated_band_spectrogram(X, fs, freq_band, window_time=0.5, overlap_perc=0.75):
    """Computes a temporal PSD map integrated over a frequency band."""
    print('Computing spectrogram with integrated band...')

    n_window = int(window_time * fs)
    n_overlap = int(overlap_perc * n_window)

    _, _, t_spec = signal.spectrogram(X[0, :], fs=fs,
                                       window='hamming',
                                       nperseg=n_window,
                                       noverlap=n_overlap)
    num_time_windows = len(t_spec)
    num_locs = X.shape[0]

    PSD_map_temporal = np.zeros((num_locs, num_time_windows))

    t0 = time.time()
    for k in range(num_locs):
        f_spec, t_spec, Sxx = signal.spectrogram(X[k, :], fs=fs,
                                                  window='hamming',
                                                  nperseg=n_window,
                                                  noverlap=n_overlap)

        psd = np.abs(Sxx) ** 2 / (fs * np.sum(np.hamming(n_window) ** 2))

        idx_band = np.where((f_spec >= freq_band[0]) & (f_spec <= freq_band[1]))[0]
        if len(idx_band) > 0:
            band_power = np.trapz(psd[idx_band, :], f_spec[idx_band], axis=0)
            PSD_map_temporal[k, :] = band_power

    print(f'Computation completed in {time.time() - t0:.2f} seconds')
    return PSD_map_temporal, t_spec

In [9]:
'''
============================================================================
2D FFT (K-F ANALYSIS)
============================================================================
'''

def fft_2d_analysis(X_roi, fs, dx, scale_factor=0.02292):
    """Performs 2D FFT (k-f) analysis on a region of interest."""
    Ny, Nt = X_roi.shape

    win_y = np.hanning(Ny)
    win_t = np.hanning(Nt)
    win_2d = np.outer(win_y, win_t)
    X_win = X_roi * win_2d

    F = fftshift(fft2(X_win))
    F_dB = 20 * np.log10(np.abs(F) + 1e-10)

    f_vec = np.fft.fftshift(np.fft.fftfreq(Nt, 1 / fs))
    dx_scaled = dx * scale_factor
    k_vec_scaled = np.fft.fftshift(np.fft.fftfreq(Ny, dx_scaled))

    return F_dB, f_vec, k_vec_scaled

In [10]:
'''
============================================================================
MAIN
============================================================================
'''
# Parameter configuration
filePath = 'onyxAcquisition_2025-11-07_19.12.59_UTC_000060.h5'
# Processing parameters
t_start_cut = 5.0
t_end_cut = 35.0
hp_cut = 0.1
lp_cut = 9000
scale_factor = 0.02292
# Visualization parameters
startFiber = 30
endFiber = 200
startFiberProfile = 120
endFiberProfile = 150
phase_min = -2
phase_max = 2

In [11]:
'''
========================================================================
1. DATA READING
========================================================================
'''
try:
    X, fs, dx, num_locs = read_h5_file(filePath)
except Exception as e:
    print(f"Error reading file: {e}")

# Create time and distance vectors
num_tr = X.shape[1]
t = np.arange(num_tr) / fs
y = np.arange(num_locs) * dx

print(f'Matrix dimensions: {num_locs} channels × {num_tr} samples ({t[-1]:.2f} seconds)')

Reading file metadata...
Metadata read successfully:
 -> fs = 20000.0 Hz
 -> dx = 0.532 m
 -> num_locs = 556 channels
Matrix dimensions: 556 channels × 556 samples (0.03 seconds)


In [12]:
'''
========================================================================
2. TIME CLIPPING
========================================================================
'''
try:
    X, t = temporal_cut(X, t, t_start_cut, t_end_cut)
    num_locs, num_tr = X.shape
except ValueError as e:
    print(f"Temporal cut error: {e}")
    print("Using full data range")

Applying temporal cut...
Temporal cut error: No times >= 5.0s found. Max time: 0.03s
Using full data range


In [13]:
'''
========================================================================
3. PRE-PROCESSING
========================================================================
'''
print('Removing DC mean from each channel...')
X = X - np.mean(X, axis=1, keepdims=True)

Removing DC mean from each channel...


In [14]:
'''
========================================================================
4. BAND-PASS FILTERING
========================================================================
'''
try:
    X = bandpass_filter(X, fs, hp_cut, lp_cut)
except ValueError as e:
    print(f"Band-pass filter error: {e}")
    print("Skipping band-pass filtering")

print('\nProcessing completed!')

Applying band-pass filter from 0.1 to 9000 Hz...
Filtering completed.

Processing completed!


In [15]:
'''
========================================================================
5. INITIAL VISUALIZATION (Figure 1)
========================================================================
'''
print('Generating visualizations...')

# Find channels for visualization
chan_view = np.argmin(np.abs(y - startFiberProfile))
chan_view2 = np.argmin(np.abs(y - endFiberProfile))

print(f' -> Channel 1 ({startFiberProfile} m) selected: #{chan_view} (Real position: {y[chan_view]:.2f} m)')
print(f' -> Channel 2 ({endFiberProfile} m) selected: #{chan_view2} (Real position: {y[chan_view2]:.2f} m)')

# Figure 1: Time signals and FFT
plt.figure(figsize=(15, 10))

# Subplot 1: Time signals
nshow = min(num_tr, int(10 * fs))  # Show first 10 seconds
ax1 = plt.subplot(2, 1, 1)
ax1.plot(t[:nshow], X[chan_view, :nshow], label=f'Channel #{chan_view} ({y[chan_view]:.1f} m)')
ax1.plot(t[:nshow], X[chan_view2, :nshow], label=f'Channel #{chan_view2} ({y[chan_view2]:.1f} m)')
ax1.set_xlabel('Time (s)')
ax1.set_ylabel('Phase (rad)')
ax1.set_title(f'Temporal Signals - Channel {chan_view} vs {chan_view2}')
ax1.legend()
ax1.grid(True)
ax1.set_xlim([t[0], min(t[0] + 5, t[-1])])  # Show first 5 seconds or less
ax1.set_ylim([phase_min, phase_max])

# Subplot 2: FFT
ax2 = plt.subplot(2, 1, 2)

# Calculate FFT for both channels
f1, P1_1 = fourier_analysis(X[chan_view, :], fs)
f2, P1_2 = fourier_analysis(X[chan_view2, :], fs)

# Smooth FFT
from scipy.ndimage import gaussian_filter1d
P1_1_smooth = gaussian_filter1d(P1_1, 10)
P1_2_smooth = gaussian_filter1d(P1_2, 10)

ax2.plot(f1, 10*np.log10(P1_1_smooth/np.max(P1_1_smooth)), label=f'Channel #{chan_view}')
ax2.plot(f2, 10*np.log10(P1_2_smooth/np.max(P1_2_smooth)), label=f'Channel #{chan_view2}')
ax2.set_xlabel('Frequency (Hz)')
ax2.set_ylabel('10*log10(P/max(P)) (dB)')
ax2.set_title('Frequency Spectrum')
ax2.legend()
ax2.grid(True)
ax2.set_xlim([0.1, min(65, fs/2)])  # Limit to Nyquist

plt.tight_layout()
plt.savefig('figure1_signals_and_fft.png', dpi=300)
plt.show()

Generating visualizations...
 -> Channel 1 (120 m) selected: #226 (Real position: 120.17 m)
 -> Channel 2 (150 m) selected: #282 (Real position: 149.95 m)


In [16]:
'''
========================================================================
6. CROSS-CORRELATION AND VELOCITY
========================================================================
'''
print('Calculating cross-correlation...')

signal1 = X[chan_view, :]
signal2 = X[chan_view2, :]

c, lags, time_delay = cross_correlation(signal1, signal2, fs)

# Calculate velocity
distance_m = abs(y[chan_view2] - y[chan_view]) * scale_factor
velocity = abs(distance_m / (time_delay + 1e-10))

print('\n--- Cross-Correlation Results ---')
print(f' -> Distance between channels: {distance_m:.2f} m')
print(f' -> Time delay: {time_delay:.4f} s')
print(f' -> Apparent velocity: {velocity:.2f} m/s')

Calculating cross-correlation...

--- Cross-Correlation Results ---
 -> Distance between channels: 0.68 m
 -> Time delay: 0.0000 s
 -> Apparent velocity: 6825066833.36 m/s


In [17]:
'''
========================================================================
7. TEMPORAL PSD MAP (Figure 2)
========================================================================
'''
print('Generating temporal PSD map...')

freq_band = [hp_cut, min(lp_cut, fs/2 - 1)]  # Ensure it doesn't exceed Nyquist
try:
    PSD_map, t_spec = integrated_band_spectrogram(X, fs, freq_band)

    plt.figure(figsize=(15, 8))

    # Subplot 1: Phase map
    ax1 = plt.subplot(2, 1, 1)
    im1 = ax1.imshow(X[:, :nshow].T, aspect='auto',
                     extent=[y[0], y[-1], t[nshow-1], t[0]],
                     cmap='jet', vmin=phase_min, vmax=phase_max)
    ax1.set_xlabel('Distance (m)')
    ax1.set_ylabel('Time (s)')
    ax1.set_title(f'Band-pass filter {hp_cut}-{lp_cut} Hz | Scale factor: {scale_factor}')
    plt.colorbar(im1, ax=ax1, label='Phase (rad)')
    ax1.set_xlim([y[0], min(630, y[-1])])
    ax1.set_ylim([max(t[0], 5), min(10, t[-1])])

    # Subplot 2: PSD map
    ax2 = plt.subplot(2, 1, 2)
    PSD_dB = 10 * np.log10(PSD_map + 1e-10)
    im2 = ax2.imshow(PSD_dB.T, aspect='auto',
                     extent=[y[0], y[-1], t_spec[-1], t_spec[0]],
                     cmap='viridis')
    ax2.set_xlabel('Distance (m)')
    ax2.set_ylabel('Time (s)')
    ax2.set_title(f'PSD {freq_band[0]}-{freq_band[1]} Hz')
    plt.colorbar(im2, ax=ax2, label='Band Intensity (dB)')
    ax2.set_xlim([y[0], min(630, y[-1])])

    plt.tight_layout()
    plt.savefig('figure2_psd_map.png', dpi=300)
    plt.show()
except Exception as e:
    print(f"Error generating PSD map: {e}")

Generating temporal PSD map...
Computing spectrogram with integrated band...
Error generating PSD map: noverlap must be less than nperseg.


In [18]:
'''
========================================================================
8. REGION OF INTEREST WITH SCALE (Figure 3)
========================================================================
'''
print('Processing region of interest...')

# Debug: Show range information
print(f"Range of y: {y[0]:.2f} to {y[-1]:.2f} m")
print(f"Looking for ROI: {startFiber} to {endFiber} m")

# Select ROI - ROBUST VERSION
mask = (y >= startFiber) & (y <= endFiber)
indices_roi = np.where(mask)[0]

if len(indices_roi) == 0:
    print(f"ERROR: No channels found in range {startFiber}-{endFiber} m")
    print("Possible solutions:")
    print(f"1. Check that startFiber ({startFiber}) <= endFiber ({endFiber})")
    print(f"2. Check that startFiber ({startFiber}) >= y minimum ({y[0]:.2f})")
    print(f"3. Check that endFiber ({endFiber}) <= y maximum ({y[-1]:.2f})")

    # Suggest close values
    if startFiber < y[0]:
        print(f"Suggestion: startFiber should be >= {y[0]:.2f}")
    if endFiber > y[-1]:
        print(f"Suggestion: endFiber should be <= {y[-1]:.2f}")

    # Use available ranges if ROI not found
    if startFiber < y[0]:
        startFiber = y[0]
        print(f"Using startFiber = {startFiber:.2f}")
    if endFiber > y[-1]:
        endFiber = y[-1]
        print(f"Using endFiber = {endFiber:.2f}")

    # Recalculate mask
    mask = (y >= startFiber) & (y <= endFiber)
    indices_roi = np.where(mask)[0]

if len(indices_roi) > 0:
    idx_roi_start = indices_roi[0]
    idx_roi_end = indices_roi[-1]

    print(f"ROI found: channels {idx_roi_start} to {idx_roi_end}")
    print(f"Actual range in y: {y[idx_roi_start]:.2f} to {y[idx_roi_end]:.2f} m")

    X_roi = X[idx_roi_start:idx_roi_end+1, :]
    y_roi = y[idx_roi_start:idx_roi_end+1]

    # Apply scale
    y_relative = y_roi - y_roi[0]
    y_final_axis = y_relative * scale_factor

    # Calculate time limits for visualization
    t_min_vis = max(t[0], 5.0)
    t_max_vis = min(t[-1], 10.0)

    # Create figure
    plt.figure(figsize=(12, 8))

    # Check that data is not empty
    if X_roi.size > 0:
        # Create the image
        im = plt.imshow(X_roi.T,
                        aspect='auto',
                        origin='lower',  # Similar to 'axis xy' in MATLAB
                        extent=[y_final_axis[0], y_final_axis[-1], t[0], t[-1]],
                        cmap='jet',
                        vmin=phase_min,
                        vmax=phase_max)

        plt.xlabel('Helical Distance (m)')
        plt.ylabel('Time (s)')
        plt.title(f'ROI from {startFiber:.1f}-{endFiber:.1f} m')
        plt.colorbar(im, label='Phase (rad)')

        # Adjust Y-axis limits if possible
        if t_min_vis < t_max_vis:
            plt.ylim([t_min_vis, t_max_vis])
            print(f"Showing time from {t_min_vis:.1f} to {t_max_vis:.1f} seconds")
        else:
            plt.ylim([t[0], t[-1]])
            print(f"Showing all time: {t[0]:.1f} to {t[-1]:.1f} seconds")

        plt.tight_layout()
        plt.show()
    else:
        print("ERROR: X_roi is empty, cannot create figure")
else:
    print(f"ERROR: No channels found after adjusting parameters")
    print(f"Available y range: {y[0]:.2f} to {y[-1]:.2f} m")

Processing region of interest...
Range of y: 0.00 to 295.12 m
Looking for ROI: 30 to 200 m
ROI found: channels 57 to 376
Actual range in y: 30.31 to 199.94 m
Showing all time: 0.0 to 0.0 seconds


In [19]:
'''
========================================================================
9. 2D FFT (K-F ANALYSIS) (Figure 5)
========================================================================
'''
print('Performing 2D FFT analysis...')

if X_roi is not None and X_roi.shape[0] > 0 and X_roi.shape[1] > 0:
    F_dB, f_vec, k_vec = fft_2d_analysis(X_roi, fs, dx, scale_factor)

    plt.figure(figsize=(12, 8))
    im = plt.imshow(F_dB, aspect='auto',
                    extent=[f_vec[0], f_vec[-1], k_vec[0], k_vec[-1]],
                    cmap='jet',
                    vmin=np.percentile(F_dB, 30),
                    vmax=np.percentile(F_dB, 99.9))
    plt.xlabel('Frequency (Hz)')
    plt.ylabel('Wavenumber (cycles/m)')
    plt.title(f'2D FFT - Region {startFiber}-{endFiber} m')
    plt.colorbar(im, label='Spectral Amplitude (dB)')
    plt.xlim([-min(5000, fs/2), min(5000, fs/2)])
    plt.ylim([-25, 25])
    plt.savefig('figure5_fft2d.png', dpi=300)
    plt.show()

Performing 2D FFT analysis...


In [20]:
'''
========================================================================
10. POWER AND RMS CALCULATION
========================================================================
'''
print('Calculating power and RMS...')

# Select ROI for power analysis
idx_power_start = np.where(y >= startFiberProfile)[0]
idx_power_end = np.where(y <= endFiberProfile)[0]

if len(idx_power_start) > 0 and len(idx_power_end) > 0:
    idx_power_start = idx_power_start[0]
    idx_power_end = idx_power_end[-1]

    if idx_power_end > idx_power_start:
        X_power_roi = X[idx_power_start:idx_power_end+1, :]

        # Remove common mode
        mean_power_trace = np.mean(X_power_roi, axis=0)
        X_power_filtered = X_power_roi - mean_power_trace

        # Calculate metrics
        avg_power = np.mean(X_power_filtered**2)
        total_rms = np.sqrt(avg_power)

        print('\n--- Energy Analysis ---')
        print(f' -> ROI: {startFiberProfile}-{endFiberProfile} m')
        print(f' -> Average Power: {avg_power:.6f} rad²')
        print(f' -> Total RMS: {total_rms:.6f} rad')

Calculating power and RMS...

--- Energy Analysis ---
 -> ROI: 120-150 m
 -> Average Power: 0.020822 rad²
 -> Total RMS: 0.144297 rad
