In [1]:
import fft_interp
import torch
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from scipy.signal import hilbert, butter, filtfilt


In [None]:
# Pick the compute device used by later cells. Apple's Metal (MPS) backend is
# preferred; when it is missing we fall back to the CPU so that `mps_device`
# is always defined and downstream `.to(mps_device)` calls do not raise
# NameError on machines without MPS.
if torch.backends.mps.is_available():
    mps_device = torch.device("mps")
    # Tiny allocation as a smoke test that the backend actually works.
    x = torch.ones(1, device=mps_device)
    print (x)
else:
    # The variable name is kept for compatibility with cells that reference it.
    mps_device = torch.device("cpu")
    print ("MPS device not found.")

In [3]:
# Path of the input video, relative to the notebook's working directory.
VIDEO_PATH = "videos/rawglasssmall.avi"

# Load the video through the project's helper; later cells index the result
# with three axes and use axis 0 as the frame axis.
frames = fft_interp.get_frames(VIDEO_PATH)



In [None]:
# Show the shape of the loaded `frames` array as the cell output.
frames.shape

In [5]:
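# Optional: uncomment to keep only the first 50 entries along axes 1 and 2
# (a smaller array for quicker experiments).
# frames = frames[:, :50, :50]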

In [6]:
def hilbert_transform_1d_torch(data_torch, axis: int = -1) -> torch.Tensor:
    """
    Compute the 1D Hilbert transform of a real tensor along one axis using
    PyTorch's FFT operations.

    The spectrum along ``axis`` is multiplied by ``-j * sign(frequency)``:
    positive-frequency bins by ``-1j``, negative-frequency bins by ``+1j``,
    while the DC bin (and the Nyquist bin for even lengths) is set to zero.

    Parameters
    ----------
    data_torch : torch.Tensor
        Real-valued tensor of any rank >= 1 (e.g. shape (T, H, W)).
    axis : int
        The axis along which to compute the transform. Negative values
        count from the last dimension.

    Returns
    -------
    torch.Tensor
        A complex tensor with the same shape and device as ``data_torch``.
        For real input the Hilbert transform is carried by the real part;
        the imaginary part is zero up to floating-point rounding. Note that
        this is the Hilbert transform itself, not the analytic signal
        ``x + j * H[x]``.
    """
    # Forward FFT along the chosen axis.
    data_fft = torch.fft.fft(data_torch, dim=axis)
    n = data_torch.size(axis)

    # Frequency-domain multiplier; its dtype follows the FFT output so that
    # float64 input keeps double precision.
    hilb_filter = torch.zeros(n, dtype=data_fft.dtype, device=data_fft.device)

    # One pair of slices covers both parities:
    #   - positive frequencies are bins 1 .. (n - 1) // 2
    #   - negative frequencies are bins n // 2 + 1 .. n - 1
    # For even n the Nyquist bin n // 2 lies between the two ranges and stays
    # zero, as does the DC bin 0.
    hilb_filter[1 : (n + 1) // 2] = -1j
    hilb_filter[n // 2 + 1 :] = 1j

    # Reshape to [1, ..., n, ..., 1] matching the input rank, so the filter
    # broadcasts along `axis` only (works for any rank and negative axes).
    shape = [1] * data_torch.dim()
    shape[axis] = n
    hilb_filter = hilb_filter.reshape(shape)

    # Apply the filter and transform back.
    return torch.fft.ifft(data_fft * hilb_filter, dim=axis)

In [None]:
# Use Apple's MPS backend when present and fall back to the CPU otherwise, so
# this cell also runs on machines without MPS.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

gpu_frames = torch.tensor(frames, dtype=torch.float32).to(device)

# Hilbert transform along axis 0 (the frame axis), brought back to NumPy.
hilbert_ed = hilbert_transform_1d_torch(gpu_frames, axis=0)
hilbert_ed = hilbert_ed.cpu().numpy()

# NOTE(review): this is the magnitude of the Hilbert transform output, not of
# the analytic signal (x + j*H[x]); the low-pass filter below smooths it into
# an envelope. Confirm this is the intended envelope detector.
hilbert_ed = np.abs(hilbert_ed)

# Low-pass Butterworth filter used to smooth the rectified signal. With no
# sampling rate given to `butter`, the cutoff is a fraction of the Nyquist
# frequency.
ENVELOPE_FILTER_ORDER = 2
ENVELOPE_CUTOFF = 0.01
b, a = butter(ENVELOPE_FILTER_ORDER, ENVELOPE_CUTOFF, btype='lowpass')

# Zero-phase filtering so the envelope peak is not shifted along axis 0.
filtered_envelope = filtfilt(b, a, hilbert_ed, axis=0)

# Index along axis 0 at which the envelope peaks, for every (row, column).
peak_frame_index = np.argmax(filtered_envelope, axis=0)

SCAN_SPEED = 0.25  # Microns per second
FPS = 30

# Distance travelled between two consecutive frames.
microns_per_frame = SCAN_SPEED / FPS

# Convert the peak frame index into a height in microns.
height_map = peak_frame_index * microns_per_frame

In [None]:
# Hand the computed height map to the project's plotting helper
# (presumably renders it; see fft_interp.plot_height_map).
fft_interp.plot_height_map(height_map)

In [None]:
# Debug plot for the single pixel (0, 0): the rectified Hilbert output, the
# raw frame values and the low-pass filtered envelope, all along axis 0.
# Every trace is named so the curves can be told apart in the legend.
fig = go.Figure()
fig.add_scatter(y=hilbert_ed[:, 0, 0], mode="lines", name="|Hilbert transform|")
fig.add_scatter(y=frames[:, 0, 0], mode="lines", name="raw frames")
fig.add_scatter(y=filtered_envelope[:, 0, 0], mode="lines", name="filtered envelope")
fig.update_layout(
    title="Pixel (0, 0) along the frame axis",
    xaxis_title="Frame index",
    yaxis_title="Value",
)
fig.show()