In [1]:
# Test computing PSDs using PyTorch

In [2]:
from torch import hann_window, linspace, stft, tensor, cos
from torch import abs as abs_torch
from scipy.signal import ShortTimeFFT
from scipy.signal.windows import hann
from numpy import pi, sqrt, sum
from numpy import abs as abs_np

from utils_torch import get_window_enbw

In [3]:
numpts = 1000
starttime = 0.0
sampling_rate = 1000.0
freq = 50.0
num_fft = 100
overlap = 0.5

endtime = starttime + (numpts - 1) / sampling_rate

In [4]:
timeax = linspace(starttime, endtime, numpts)
signal = cos(2 * pi * freq * timeax)

In [5]:
stft_torch = stft(signal, num_fft, window = hann_window(num_fft), return_complex = True)

In [6]:
amp_torch = 2 * abs_torch(stft_torch) / num_fft
power_torch = amp_torch ** 2 / 2
enbw = get_window_enbw(hann_window(num_fft), sampling_rate)
psd_torch = power_torch / enbw

print(amp_torch.max())
print(power_torch.max())
print(psd_torch.max())

tensor(0.5000)
tensor(0.1250)
tensor(0.0083)


In [7]:
hop = int(num_fft * (1 - overlap))
STFT = ShortTimeFFT(hann(num_fft), hop, sampling_rate, scale_to = "psd")
stft_scipy = STFT.stft(signal)
amp_scipy = abs_np(stft_scipy)
psd_scipy = STFT.spectrogram(signal)

In [8]:
# print(amp_scipy.max() ** 2)
# print(psd_scipy.max())

window = hann(num_fft)
window_norm = STFT.win
print(window_norm[1] ** 2 / window[1] ** 2)

print(1 / sum(window ** 2) / sampling_rate)
print(sum(window ** 2) * sampling_rate / sum(window) ** 2)
print(enbw)


2.6936026936026937e-05
2.6936026936026934e-05
15.151515151515152
tensor(15.0000)


In [None]:
10 * 