Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/changelog.d/287.added.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
feat: improve STFT plot computation time
52 changes: 33 additions & 19 deletions examples/004_isolate_orders.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,46 +75,60 @@
# more control over what you are displaying.
# While you could use the ``Stft.plot()`` method, the custom function
# defined here restricts the frequency range of the plot.
def plot_stft(stft_class, vmax):
out = stft_class.get_output_as_nparray()

# Extract first half of the STFT (second half is symmetrical)
half_nfft = int(out.shape[0] / 2) + 1
magnitude = stft_class.get_stft_magnitude_as_nparray()
def plot_stft(
stft: Stft,
SPLmax: float,
title: str = "STFT",
maximum_frequency: float = MAX_FREQUENCY_PLOT_STFT,
) -> None:
"""Plot a short-term Fourier transform (STFT) into a figure window.

Parameters
----------
stft: Stft
Object containing the STFT.
SPLmax: float
Maximum value (here in dB SPL) for the colormap.
title: str, default: "STFT"
Title of the figure.
maximum_frequency: float, default: MAX_FREQUENCY_PLOT_STFT
Maximum frequency in Hz to display.
"""
magnitude = stft.get_stft_magnitude_as_nparray()

# Only extract the first half of the STFT, as it is symmetrical
half_nfft = int(magnitude.shape[0] / 2) + 1

# Deliberately silence NumPy's divide-by-zero warning: log10 of a zero magnitude yields -inf
np.seterr(divide="ignore")
magnitude = 20 * np.log10(magnitude[0:half_nfft, :])
np.seterr(divide="warn")

# Obtain sampling frequency, time steps, and number of time samples
fs = 1.0 / (
stft_class.signal.time_freq_support.time_frequencies.data[1]
- stft_class.signal.time_freq_support.time_frequencies.data[0]
)
time_step = np.floor(stft_class.fft_size * (1.0 - stft_class.window_overlap) + 0.5) / fs
num_time_index = len(stft_class.get_output().get_available_ids_for_label("time"))
time_data = stft.signal.time_freq_support.time_frequencies.data
time_step = time_data[1] - time_data[0]
fs = 1.0 / time_step
num_time_index = len(stft.get_output().get_available_ids_for_label("time"))

# Define boundaries of the plot
extent = [0, time_step * num_time_index, 0.0, fs / 2.0]

# Plot
plt.figure()
plt.imshow(
magnitude,
origin="lower",
aspect="auto",
cmap="jet",
extent=extent,
vmin=vmax - 70.0,
vmax=vmax,
vmax=SPLmax,
vmin=SPLmax - 70.0,
)
plt.colorbar(label="Amplitude (dB SPL)")
plt.colorbar(label="Magnitude (dB SPL)")
plt.ylabel("Frequency (Hz)")
plt.xlabel("Time (s)")
plt.ylim(
[0.0, MAX_FREQUENCY_PLOT_STFT]
) # Change the value of MAX_FREQUENCY_PLOT_STFT if needed
plt.title("STFT")
plt.ylim([0.0, maximum_frequency])  # Restrict the displayed frequency range to maximum_frequency
plt.title(title)
plt.show()


Expand Down
34 changes: 18 additions & 16 deletions examples/005_xtract_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,38 +79,40 @@
# more control over what you are displaying.
# While you could use the ``Stft.plot()`` method, the custom function
# defined here restricts the frequency range of the plot.
def plot_stft(stft_class, SPLmax, title="STFT", maximum_frequency=MAX_FREQUENCY_PLOT_STFT):
def plot_stft(
stft: Stft,
SPLmax: float,
title: str = "STFT",
maximum_frequency: float = MAX_FREQUENCY_PLOT_STFT,
) -> None:
"""Plot a short-term Fourier transform (STFT) into a figure window.

Parameters
----------
stft_class: Stft
stft: Stft
Object containing the STFT.
SPLmax: float
Maximum value (here in dB SPL) for the colormap.
title: str
title: str, default: "STFT"
Title of the figure.
maximum_frequency: float
maximum_frequency: float, default: MAX_FREQUENCY_PLOT_STFT
Maximum frequency in Hz to display.
"""
out = stft_class.get_output_as_nparray()
magnitude = stft.get_stft_magnitude_as_nparray()

# Extract first half of the STFT (second half is symmetrical)
half_nfft = int(out.shape[0] / 2) + 1
magnitude = stft_class.get_stft_magnitude_as_nparray()
# Only extract the first half of the STFT, as it is symmetrical
half_nfft = int(magnitude.shape[0] / 2) + 1

# Deliberately silence NumPy's divide-by-zero warning: log10 of a zero magnitude yields -inf
np.seterr(divide="ignore")
magnitude = 20 * np.log10(magnitude[0:half_nfft, :])
np.seterr(divide="warn")

# Obtain sampling frequency, time steps, and number of time samples
fs = 1.0 / (
stft_class.signal.time_freq_support.time_frequencies.data[1]
- stft_class.signal.time_freq_support.time_frequencies.data[0]
)
time_step = np.floor(stft_class.fft_size * (1.0 - stft_class.window_overlap) + 0.5) / fs
num_time_index = len(stft_class.get_output().get_available_ids_for_label("time"))
time_data = stft.signal.time_freq_support.time_frequencies.data
time_step = time_data[1] - time_data[0]
fs = 1.0 / time_step
num_time_index = len(stft.get_output().get_available_ids_for_label("time"))

# Define boundaries of the plot
extent = [0, time_step * num_time_index, 0.0, fs / 2.0]
Expand All @@ -124,7 +126,7 @@ def plot_stft(stft_class, SPLmax, title="STFT", maximum_frequency=MAX_FREQUENCY_
cmap="jet",
extent=extent,
vmax=SPLmax,
vmin=(SPLmax - 70.0),
vmin=SPLmax - 70.0,
)
plt.colorbar(label="Magnitude (dB SPL)")
plt.ylabel("Frequency (Hz)")
Expand Down Expand Up @@ -328,7 +330,7 @@ def plot_stft(stft_class, SPLmax, title="STFT", maximum_frequency=MAX_FREQUENCY_
# Compute and plot the STFT
stft_original.signal = time_domain_signal
stft_original.process()
plot_stft(stft_class=stft_original, SPLmax=max_stft, title=f"STFT for signal {signal_name}")
plot_stft(stft=stft_original, SPLmax=max_stft, title=f"STFT for signal {signal_name}")

# Use Xtract with the loaded signal
xtract.input_signal = time_domain_signal
Expand Down
30 changes: 13 additions & 17 deletions src/ansys/sound/core/spectrogram_processing/stft.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,19 +190,18 @@ def get_output_as_nparray(self) -> np.ndarray:
"""
output = self.get_output()

num_time_index = len(output.get_available_ids_for_label("time"))
time_indexes = output.get_available_ids_for_label("time")
Ntime = len(time_indexes)
Nfft = output.get_field({"complex": 0, "time": 0, "channel_number": 0}).data.shape[0]

f1 = output.get_field({"complex": 0, "time": 0, "channel_number": 0})
f2 = output.get_field({"complex": 1, "time": 0, "channel_number": 0})
# Pre-allocate memory for the output array.
out_as_np_array = np.empty((Ntime, Nfft), dtype=np.complex128)

out_as_np_array = f1.data + 1j * f2.data
for i in range(1, num_time_index):
for i in time_indexes:
f1 = output.get_field({"complex": 0, "time": i, "channel_number": 0})
f2 = output.get_field({"complex": 1, "time": i, "channel_number": 0})
tmp_arr = f1.data + 1j * f2.data
out_as_np_array = np.vstack((out_as_np_array, tmp_arr))
out_as_np_array[i] = f1.data + 1j * f2.data

# return out_as_np_array
return np.transpose(out_as_np_array)

def get_stft_magnitude_as_nparray(self) -> np.ndarray:
Expand Down Expand Up @@ -232,22 +231,19 @@ def plot(self):

This method plots the STFT amplitude and the associated phase.
"""
out = self.get_output_as_nparray()

# Extracting first half of the STFT (second half is symmetrical)
half_nfft = int(np.shape(out)[0] / 2) + 1
magnitude = self.get_stft_magnitude_as_nparray()

# Only extract the first half of the STFT, as it is symmetrical
half_nfft = int(np.shape(magnitude)[0] / 2) + 1

np.seterr(divide="ignore")
magnitude = 20 * np.log10(magnitude[0:half_nfft, :])
np.seterr(divide="warn")
phase = self.get_stft_phase_as_nparray()
phase = phase[0:half_nfft, :]
fs = 1.0 / (
self.signal.time_freq_support.time_frequencies.data[1]
- self.signal.time_freq_support.time_frequencies.data[0]
)
time_step = np.floor(self.fft_size * (1.0 - self.window_overlap) + 0.5) / fs
time_data = self.signal.time_freq_support.time_frequencies.data
time_step = time_data[1] - time_data[0]
fs = 1.0 / time_step
num_time_index = len(self.get_output().get_available_ids_for_label("time"))

# Boundaries of the plot
Expand Down